Train

Pew pew pew

Self-Supervised PatchTFT Lightning


source

PatchTFTSimpleLightning


def PatchTFTSimpleLightning(
    learning_rate, train_size, batch_size, channels, metrics, precalculate_onebatch_stft_stats:bool=False,
    use_sequence_padding_mask:bool=False, loss_func:str='mse', max_lr:float=0.01, weight_decay:float=0.0,
    epochs:int=100, one_cycle_scheduler:bool=True, optimizer_type:str='Adam', scheduler_type:str='OneCycle',
    cross_attention:bool=False, # not implemented
    use_mask:bool=False,
    patch_continuity_loss:int=0, # indicator and weight of the patch continuity loss, which penalizes large discontinuities between adjacent patches
    huber_delta:NoneType=None, # delta for the Huber loss; unused for other loss functions
    patchmeup_kwargs:VAR_KEYWORD
):

Hooks to be used in LightningModule.
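A minimal instantiation sketch is shown below, assuming PatchTFTSimpleLightning is imported from this package; the dataset size, channel count, and scheduler settings are illustrative assumptions, not values taken from this repo:

# Hypothetical usage sketch (train_size, channels, and metrics are assumptions)
model = PatchTFTSimpleLightning(
    learning_rate=1e-3,
    train_size=10_000,       # number of training samples (assumed), used by the OneCycle scheduler
    batch_size=64,
    channels=7,              # number of input signal channels (assumed)
    metrics=[],
    loss_func='mse',
    max_lr=0.01,
    epochs=100,
    one_cycle_scheduler=True,
)
# trainer = pl.Trainer(max_epochs=100)
# trainer.fit(model, train_dataloader, val_dataloader)

The quick check below verifies how a shared encoder behaves once it is frozen inside another LightningModule: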

import torch.nn as nn
import pytorch_lightning as pl

class EncoderTestFreezingWeights(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(750, 512)
    def forward(self, x):
        return self.encoder(x)

class DecoderTest(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.freeze()  # disables gradients and switches the encoder to eval mode
    def training_step(self, batch, batch_idx):
        # inspect the frozen encoder: requires_grad should be False for every parameter
        print(list(self.encoder.parameters()))
        print(self.encoder.training)  # False after freeze()
        # x = self.encoder(x)
        # return x

encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)

decoder.training_step(0, 0)
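Because EncoderTestFreezingWeights is a LightningModule, calling freeze() sets requires_grad=False on all of its parameters and puts it in eval mode, so the printed parameters and training flag confirm that the shared encoder will not be updated while the decoder trains.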

Supervised Training with Linear Probing

PatchTFT Masked Autoregression SS Prediction


source

PatchTFTSleepStage


def PatchTFTSleepStage(
    learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler=True
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    linear_probing_head, # model head to train
    metrics:list=[], # metrics to calculate
    fine_tune:bool=False, # if True, fine-tune the encoder; if False, freeze the encoder weights and perform linear probing
    loss_fxn:str='CrossEntropy', # loss function to use, can be CrossEntropy or FocalLoss
    class_weights:NoneType=None, # weights of classes to use in CE loss fxn
    gamma:float=2.0, # focusing parameter for focal loss
    label_smoothing:int=0, # label smoothing for cross entropy loss
    use_sequence_padding_mask:bool=False, # indicator to use the sequence padding mask when training/in the loss fxn
    y_padding_mask:int=-100, # padding value added to the targets; indices with this value are ignored when computing the loss
    max_lr:float=0.01, # maximum learning rate for one_cycle_scheduler
    epochs:int=100, # number of epochs for one_cycle_scheduler
    one_cycle_scheduler:bool=True, # indicator to use a one cycle scheduler to vary the learning rate
    weight_decay:float=0.0, # weight decay for Adam optimizer
    pretrained_encoder_path:NoneType=None, # path of the pretrained model to use for linear probing
    optimizer_type:str='Adam', # optimizer to use, 'Adam' or 'AdamW'
    scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
    preloaded_model:NoneType=None, # loaded pretrained model to use for linear probing
    torch_model_name:str='model', # name of the PyTorch model attribute within the Lightning module; used when removing layers (for example lightning_model.pytorch_model.head = nn.Identity())
    remove_pretrain_layers:list=['head', 'mask'], # layers within the lightning model or lightning model.pytorch_model to remove
    return_softmax:bool=True, # indicator to return softmax probabilities in forward and predict_step
):

Hooks to be used in LightningModule.
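A minimal linear-probing sketch is shown below; the head dimensions, class count, and checkpoint path are illustrative assumptions, not values from this repo:

import torch.nn as nn

# Hypothetical linear-probing setup: freeze the pretrained encoder and train only the head.
# The 512-dim features, 5 sleep-stage classes, and checkpoint path are assumptions.
probing_head = nn.Linear(512, 5)
model = PatchTFTSleepStage(
    learning_rate=1e-3,
    train_size=10_000,
    batch_size=64,
    linear_probing_head=probing_head,
    fine_tune=False,                # keep encoder weights frozen (linear probing)
    loss_fxn='CrossEntropy',
    pretrained_encoder_path='path/to/pretrained_encoder.ckpt',  # hypothetical path
    remove_pretrain_layers=['head', 'mask'],
    return_softmax=True,
)

Setting fine_tune=True instead would unfreeze the encoder and update it together with the head.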