Train

Pew pew pew

Self-Supervised PatchTFT Lightning



PatchTFTSimpleLightning

 PatchTFTSimpleLightning (learning_rate, train_size, batch_size, channels,
                          metrics, precalculate_onebatch_stft_stats=False,
                          use_sequence_padding_mask=False,
                          loss_func='mse', max_lr=0.01, weight_decay=0.0,
                          epochs=100, one_cycle_scheduler=True,
                          optimizer_type='Adam',
                          scheduler_type='OneCycle',
                          cross_attention=False, use_mask=False,
                          patch_continuity_loss=0, huber_delta=None,
                          **patchmeup_kwargs)

LightningModule for self-supervised pretraining of a PatchTFT model.

| | Type | Default | Details |
|---|---|---|---|
| learning_rate | | | |
| train_size | | | |
| batch_size | | | |
| channels | | | |
| metrics | | | |
| precalculate_onebatch_stft_stats | bool | False | |
| use_sequence_padding_mask | bool | False | |
| loss_func | str | mse | |
| max_lr | float | 0.01 | |
| weight_decay | float | 0.0 | |
| epochs | int | 100 | |
| one_cycle_scheduler | bool | True | |
| optimizer_type | str | Adam | |
| scheduler_type | str | OneCycle | |
| cross_attention | bool | False | Not implemented |
| use_mask | bool | False | |
| patch_continuity_loss | int | 0 | Indicator and weight of the patch continuity loss, which penalizes large discontinuities between adjacent patches (0 disables it) |
| huber_delta | NoneType | None | Delta for the Huber loss; not used otherwise |
| patchmeup_kwargs | VAR_KEYWORD | | |
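The two loss-shaping options above (huber_delta and patch_continuity_loss) can be illustrated with a minimal pure-Python sketch for a single channel. It assumes that the reconstruction loss is MSE (the loss_func default) unless a Huber delta is supplied, and that a nonzero patch_continuity_loss acts as the weight of an extra penalty on the jump between the last value of one predicted patch and the first value of the next. The functions huber, mse, patch_continuity_penalty and pretrain_loss are hypothetical illustrations, not part of the module's API.

```python
def huber(pred, target, delta):
    """Mean Huber loss: quadratic for |error| <= delta, linear beyond it."""
    total = 0.0
    for p, t in zip(pred, target):
        err = abs(p - t)
        if err <= delta:
            total += 0.5 * err ** 2
        else:
            total += delta * (err - 0.5 * delta)
    return total / len(pred)


def mse(pred, target):
    """Mean squared error between two equal-length sequences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def patch_continuity_penalty(patches):
    """Hypothetical penalty: mean squared jump across patch boundaries.

    `patches` is a list of patches (each a list of floats) for one channel,
    in temporal order. Only boundaries are compared: the last value of
    patch i against the first value of patch i + 1.
    """
    jumps = [(nxt[0] - prev[-1]) ** 2 for prev, nxt in zip(patches, patches[1:])]
    return sum(jumps) / len(jumps) if jumps else 0.0


def pretrain_loss(pred_patches, target_patches, huber_delta=None, patch_continuity_loss=0):
    """Hypothetical combined loss for one channel.

    Uses the Huber loss when huber_delta is given and MSE otherwise, then adds
    the continuity penalty on the predicted patches scaled by
    patch_continuity_loss (0 means the penalty is switched off).
    """
    pred = [v for patch in pred_patches for v in patch]
    target = [v for patch in target_patches for v in patch]
    if huber_delta is not None:
        loss = huber(pred, target, huber_delta)
    else:
        loss = mse(pred, target)
    if patch_continuity_loss:
        loss += patch_continuity_loss * patch_continuity_penalty(pred_patches)
    return loss
```

For example, with predicted patches [[0.0, 1.0], [3.0, 3.0]] and targets [[0.0, 1.0], [1.0, 1.0]], the boundary jump from 1.0 to 3.0 is penalized only when patch_continuity_loss is nonzero, while huber_delta switches the reconstruction term from MSE to the Huber loss, which grows linearly rather than quadratically for large errors.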
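```python
import torch.nn as nn
import pytorch_lightning as pl  # `import lightning.pytorch as pl` works the same way


class EncoderTestFreezingWeights(pl.LightningModule):
    """Toy encoder used to check that freezing works."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(750, 512)

    def forward(self, x):
        return self.encoder(x)


class DecoderTest(pl.LightningModule):
    """Wraps an encoder and freezes it, as done for linear probing."""

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # LightningModule.freeze() sets requires_grad=False on every parameter
        # and switches the module to eval mode
        self.encoder.freeze()

    def training_step(self, batch, batch_idx):
        # After freeze(): every flag should be False, and training should be False
        print([p.requires_grad for p in self.encoder.parameters()])
        print(self.encoder.training)


encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)

# Call the hook directly with dummy batch/batch_idx values to inspect the frozen encoder
decoder.training_step(0, 0)
```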

Supervised Training with Linear Probing

PatchTFT Masked Autoregression Sleep Stage Prediction



PatchTFTSleepStage

 PatchTFTSleepStage (learning_rate, train_size, batch_size,
                     linear_probing_head, metrics=[], fine_tune=False,
                     loss_fxn='CrossEntropy', class_weights=None,
                     gamma=2.0, label_smoothing=0,
                     use_sequence_padding_mask=False, y_padding_mask=-100,
                     max_lr=0.01, epochs=100, one_cycle_scheduler=True,
                     weight_decay=0.0, pretrained_encoder_path=None,
                     optimizer_type='Adam', scheduler_type='OneCycle',
                     preloaded_model=None, torch_model_name='model',
                     remove_pretrain_layers=['head', 'mask'],
                     return_softmax=True)

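LightningModule for sleep stage prediction with a pretrained PatchTFT encoder, using either linear probing (frozen encoder) or fine-tuning.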

| | Type | Default | Details |
|---|---|---|---|
| learning_rate | | | Desired learning rate; the initial learning rate if one_cycle_scheduler is used |
| train_size | | | Size of the training data (for one_cycle_scheduler=True) |
| batch_size | | | Batch size (for one_cycle_scheduler=True) |
| linear_probing_head | | | Model head to train |
| metrics | list | [] | Metrics to calculate |
| fine_tune | bool | False | If True, fine-tune the encoder; otherwise perform linear probing with frozen encoder weights |
| loss_fxn | str | CrossEntropy | Loss function to use: 'CrossEntropy' or 'FocalLoss' |
| class_weights | NoneType | None | Class weights for the cross-entropy loss |
| gamma | float | 2.0 | Focusing parameter for the focal loss |
| label_smoothing | int | 0 | Label smoothing for the cross-entropy loss |
| use_sequence_padding_mask | bool | False | Whether to use the sequence padding mask during training and in the loss function |
| y_padding_mask | int | -100 | Padding value added to the targets; positions with this value are ignored when computing the loss |
| max_lr | float | 0.01 | Maximum learning rate for the one-cycle scheduler |
| epochs | int | 100 | Number of epochs for the one-cycle scheduler |
| one_cycle_scheduler | bool | True | Whether to use a one-cycle scheduler to vary the learning rate |
| weight_decay | float | 0.0 | Weight decay for the Adam/AdamW optimizer |
| pretrained_encoder_path | NoneType | None | Path of the pretrained model to use for linear probing |
| optimizer_type | str | Adam | Optimizer to use: 'Adam' or 'AdamW' |
| scheduler_type | str | OneCycle | Scheduler to use: 'OneCycle' or 'CosineAnnealingWarmRestarts' |
| preloaded_model | NoneType | None | Already-loaded pretrained model to use for linear probing |
| torch_model_name | str | model | Name of the PyTorch model inside the Lightning module, used to remove layers (for example, lightning_model.pytorch_model.head = nn.Identity()) |
| remove_pretrain_layers | list | ['head', 'mask'] | Layers to remove from the Lightning model or its PyTorch model |
| return_softmax | bool | True | Whether to return softmax probabilities in forward and predict_step |
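To make loss_fxn, class_weights, gamma, label_smoothing and y_padding_mask concrete, here is a minimal pure-Python sketch of a per-position classification loss. The function sleep_stage_loss and its list-based inputs are hypothetical. It uses the textbook definitions of label-smoothed cross-entropy and focal loss, and reduces with a class-weighted mean over non-padded positions; that reduction is an assumption about how the module combines per-position losses.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a list of scores."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sleep_stage_loss(logits, targets, loss_fxn="CrossEntropy", class_weights=None,
                     gamma=2.0, label_smoothing=0.0, y_padding_mask=-100):
    """Hypothetical classification loss over a padded sequence of epochs.

    logits:  one list of C class scores per position
    targets: one int label per position; positions equal to y_padding_mask
             are skipped entirely
    """
    num_classes = len(logits[0])
    weights = class_weights or [1.0] * num_classes
    total, weight_sum = 0.0, 0.0
    for scores, y in zip(logits, targets):
        if y == y_padding_mask:
            continue  # padded position: excluded from the loss
        logp = log_softmax(scores)
        if loss_fxn == "FocalLoss":
            # focal loss: down-weights well-classified positions by (1 - p_t)^gamma
            p_t = math.exp(logp[y])
            loss = -((1.0 - p_t) ** gamma) * logp[y]
        else:
            # cross-entropy against a target mixed with the uniform distribution
            smooth = [label_smoothing / num_classes] * num_classes
            smooth[y] += 1.0 - label_smoothing
            loss = -sum(q * lp for q, lp in zip(smooth, logp))
        total += weights[y] * loss
        weight_sum += weights[y]
    return total / weight_sum if weight_sum else 0.0
```

With gamma=0 the focal branch reduces to plain cross-entropy, which is a quick sanity check. In PyTorch, the cross-entropy branch would typically be expressed as torch.nn.CrossEntropyLoss(weight=..., ignore_index=-100, label_smoothing=...); -100 is that loss's default ignore_index, which is presumably why y_padding_mask defaults to it. This sketch is not guaranteed to match PyTorch numerically when class weights and label smoothing are combined.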