class EncoderTestFreezingWeights(pl.LightningModule):
    """Minimal single-layer encoder used in a weight-freezing experiment.

    Holds one ``nn.Linear`` layer that maps 750 input features to 512 output
    features.
    """

    def __init__(self):
        super().__init__()
        in_features, out_features = 750, 512
        self.encoder = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Apply the linear encoder to ``x`` and return the result.

        ``nn.Linear`` requires the last dimension of ``x`` to be 750.
        """
        return self.encoder(x)
class DecoderTest(pl.LightningModule):
    """Scratch module that wraps an encoder, freezes it, and reports its state.

    Appears to exist to inspect what ``freeze()`` does to a wrapped
    LightningModule's parameters and ``training`` flag.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        # LightningModule.freeze() is documented to disable gradients on all
        # parameters and switch the module to eval mode.
        self.encoder.freeze()

    def training_step(self, batch, batch_idx):
        """Print the encoder's parameters, then its ``training`` flag.

        ``batch`` and ``batch_idx`` are not used, and the method returns None.
        """
        encoder = self.encoder
        print(list(encoder.parameters()))
        print(encoder.training)
        # TODO(review): no forward pass yet; a real step would run
        # self.encoder on the batch and return a loss.
# Build the encoder, wrap it in DecoderTest (which freezes it), then call the
# training hook directly, outside a Trainer. The two arguments are placeholder
# values for batch and batch_idx.
encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)
decoder.training_step(0, 0)

# Train
Pew pew pew
Self Supervised PatchTFT Lightning
PatchTFTSimpleLightning
def PatchTFTSimpleLightning(
learning_rate, train_size, batch_size, channels, metrics, precalculate_onebatch_stft_stats:bool=False,
use_sequence_padding_mask:bool=False, loss_func:str='mse', max_lr:float=0.01, weight_decay:float=0.0,
epochs:int=100, one_cycle_scheduler:bool=True, optimizer_type:str='Adam', scheduler_type:str='OneCycle',
cross_attention:bool=False, # not implemented
use_mask:bool=False,
patch_continuity_loss:int=0, # indicator and ratio of the patch continuity loss term, which ensures patches don't have large discontinuities
huber_delta:NoneType=None, # delta for the Huber loss; unused by other loss functions
patchmeup_kwargs:VAR_KEYWORD
):
Hooks to be used in LightningModule.
Supervised Training with Linear Probing
PatchTFT Masked Autoregression SS Prediction
PatchTFTSleepStage
def PatchTFTSleepStage(
learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler is used
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
linear_probing_head, # model head to train
metrics:list=[], # metrics to calculate
fine_tune:bool=False, # indicator to finetune encoder model or perform linear probing and freeze encoder weights
loss_fxn:str='CrossEntropy', # loss function to use, can be CrossEntropy or FocalLoss
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
gamma:float=2.0, # for focal loss
label_smoothing:int=0, # label smoothing for cross entropy loss
use_sequence_padding_mask:bool=False, # indicator to use the sequence padding mask when training/in the loss fxn
y_padding_mask:int=-100, # padding value added to the target; positions with this value are ignored when computing the loss
max_lr:float=0.01, # maximum learning rate for one_cycle_scheduler
epochs:int=100, # number of epochs for one_cycle_scheduler
one_cycle_scheduler:bool=True, # indicator to use a one cycle scheduler to vary the learning rate
weight_decay:float=0.0, # weight decay for Adam optimizer
pretrained_encoder_path:NoneType=None, # path of the pretrained model to use for linear probing
optimizer_type:str='Adam', # optimizer to use, 'Adam' or 'AdamW'
scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
preloaded_model:NoneType=None, # loaded pretrained model to use for linear probing
torch_model_name:str='model', # name of the pytorch model within the lightning model module, this is to remove layers (for example lightning_model.pytorch_model.head = nn.Identity())
remove_pretrain_layers:list=['head', 'mask'], # layers within the lightning model or lightning model.pytorch_model to remove
return_softmax:bool=True, # indicator to return softmax probabilities in forward and predict_step
):
Hooks to be used in LightningModule.