```python
from torch import nn
import pytorch_lightning as pl


class EncoderTestFreezingWeights(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(750, 512)

    def forward(self, x):
        x = self.encoder(x)
        return x


class DecoderTest(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.freeze()  # freeze the pretrained encoder so its weights are not updated

    def training_step(self, batch, batch_idx):
        # inspect the frozen encoder: its parameters and its training flag
        print(list(self.encoder.parameters()))
        print(self.encoder.training)
        # x = self.encoder(x)
        # return x


encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)
decoder.training_step(0, 0)
```
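Because `freeze()` sets `requires_grad=False` on every encoder parameter and puts the module in eval mode, the second print statement shows `self.encoder.training` as `False`, confirming the encoder weights will not be updated while the decoder trains.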
Train
Pew pew pew
Self Supervised PatchTFT Lightning
PatchTFTSimpleLightning
PatchTFTSimpleLightning (learning_rate, train_size, batch_size, channels, metrics, precalculate_onebatch_stft_stats=False, use_sequence_padding_mask=False, loss_func='mse', max_lr=0.01, weight_decay=0.0, epochs=100, one_cycle_scheduler=True, optimizer_type='Adam', scheduler_type='OneCycle', cross_attention=False, use_mask=False, patch_continuity_loss=0, huber_delta=None, **patchmeup_kwargs)
Hooks to be used in LightningModule.
| | Type | Default | Details |
|---|---|---|---|
| learning_rate | | | |
| train_size | | | |
| batch_size | | | |
| channels | | | |
| metrics | | | |
| precalculate_onebatch_stft_stats | bool | False | |
| use_sequence_padding_mask | bool | False | |
| loss_func | str | mse | |
| max_lr | float | 0.01 | |
| weight_decay | float | 0.0 | |
| epochs | int | 100 | |
| one_cycle_scheduler | bool | True | |
| optimizer_type | str | Adam | |
| scheduler_type | str | OneCycle | |
| cross_attention | bool | False | not implemented |
| use_mask | bool | False | |
| patch_continuity_loss | int | 0 | indicator and ratio of the patch continuity loss function, which ensures patches don't have large discontinuities |
| huber_delta | NoneType | None | Huber loss delta; not used otherwise |
| patchmeup_kwargs | VAR_KEYWORD | | |
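
A minimal construction sketch for `PatchTFTSimpleLightning` is below; every hyperparameter value, the number of channels, and the trainer/dataloader names are illustrative assumptions rather than recommended settings.

```python
# Minimal sketch, assuming dataloaders are defined elsewhere;
# all values below are illustrative assumptions.
import pytorch_lightning as pl

model = PatchTFTSimpleLightning(
    learning_rate=1e-3,       # assumed initial learning rate
    train_size=10_000,        # assumed number of training samples (used by the OneCycle scheduler)
    batch_size=64,            # assumed batch size
    channels=2,               # assumed number of input signal channels
    metrics=[],               # no extra metrics in this sketch
    loss_func='mse',
    max_lr=0.01,
    epochs=100,
    one_cycle_scheduler=True,
)

trainer = pl.Trainer(max_epochs=100)
# trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl)  # hypothetical dataloaders
```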
Supervised Training with Linear Probing
PatchTFT Masked Autoregression SS Prediction
PatchTFTSleepStage
PatchTFTSleepStage (learning_rate, train_size, batch_size, linear_probing_head, metrics=[], fine_tune=False, loss_fxn='CrossEntropy', class_weights=None, gamma=2.0, label_smoothing=0, use_sequence_padding_mask=False, y_padding_mask=-100, max_lr=0.01, epochs=100, one_cycle_scheduler=True, weight_decay=0.0, pretrained_encoder_path=None, optimizer_type='Adam', scheduler_type='OneCycle', preloaded_model=None, torch_model_name='model', remove_pretrain_layers=['head', 'mask'], return_softmax=True)
Hooks to be used in LightningModule.
| | Type | Default | Details |
|---|---|---|---|
| learning_rate | | | desired learning rate; used as the initial learning rate if one_cycle_scheduler=True |
| train_size | | | the training data size (for one_cycle_scheduler=True) |
| batch_size | | | the batch size (for one_cycle_scheduler=True) |
| linear_probing_head | | | model head to train |
| metrics | list | [] | metrics to calculate |
| fine_tune | bool | False | indicator to fine-tune the encoder model; if False, perform linear probing and freeze the encoder weights |
| loss_fxn | str | CrossEntropy | loss function to use; can be CrossEntropy or FocalLoss |
| class_weights | NoneType | None | class weights to use in the cross-entropy loss function |
| gamma | float | 2.0 | gamma parameter for focal loss |
| label_smoothing | int | 0 | label smoothing for cross-entropy loss |
| use_sequence_padding_mask | bool | False | indicator to use the sequence padding mask during training/in the loss function |
| y_padding_mask | int | -100 | padding value added to the target; target indices with this value are ignored when computing the loss |
| max_lr | float | 0.01 | maximum learning rate for one_cycle_scheduler |
| epochs | int | 100 | number of epochs for one_cycle_scheduler |
| one_cycle_scheduler | bool | True | indicator to use a one-cycle scheduler to vary the learning rate |
| weight_decay | float | 0.0 | weight decay for the Adam optimizer |
| pretrained_encoder_path | NoneType | None | path of the pretrained model to use for linear probing |
| optimizer_type | str | Adam | optimizer to use, 'Adam' or 'AdamW' |
| scheduler_type | str | OneCycle | scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts' |
| preloaded_model | NoneType | None | loaded pretrained model to use for linear probing |
| torch_model_name | str | model | name of the PyTorch model attribute within the Lightning module; used to remove layers (for example, lightning_model.pytorch_model.head = nn.Identity()) |
| remove_pretrain_layers | list | ['head', 'mask'] | layers within the Lightning model or lightning_model.pytorch_model to remove |
| return_softmax | bool | True | indicator to return softmax probabilities in forward and predict_step |
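
A linear-probing sketch for `PatchTFTSleepStage` is below; the head dimensions, checkpoint path, and hyperparameter values are illustrative assumptions, not values taken from the library.

```python
# Minimal linear-probing sketch; head sizes, the checkpoint path, and
# all hyperparameters are illustrative assumptions.
from torch import nn
import pytorch_lightning as pl

probing_head = nn.Linear(512, 5)  # hypothetical head: 512-d encoder features -> 5 sleep stages

clf = PatchTFTSleepStage(
    learning_rate=1e-3,                                   # assumed learning rate
    train_size=10_000,                                    # assumed training-set size
    batch_size=64,                                        # assumed batch size
    linear_probing_head=probing_head,
    fine_tune=False,                                      # freeze the encoder; train only the head
    loss_fxn='CrossEntropy',
    pretrained_encoder_path='checkpoints/pretrain.ckpt',  # hypothetical checkpoint path
    remove_pretrain_layers=['head', 'mask'],              # drop pre-training-only layers from the encoder
)

trainer = pl.Trainer(max_epochs=100)
# trainer.fit(clf, train_dataloaders=train_dl, val_dataloaders=val_dl)  # hypothetical dataloaders
```

Setting `fine_tune=True` instead would leave the encoder weights trainable so the whole network is updated, rather than only the probing head.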