Train
GPUs pulling all the weight

import torch

hypnogram_index = 7
x_index = torch.tensor(list(range(hypnogram_index)))
hypnogram_index = torch.tensor([hypnogram_index])
Supervised Training with Linear Probing
PatchTFT Masked Autoregression Sleep Stage Prediction
PatchTFTSleepStage
def PatchTFTSleepStage(
learning_rate, # desired learning rate; used as the initial learning rate if one_cycle_scheduler=True
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
n_gpus, # number of GPUs to use
linear_probing_head, # model head to train
preloaded_model, # loaded pretrained model to use for linear probing
metrics:dict={}, # metrics to calculate
fine_tune:bool=False, # if True, fine-tune the encoder; if False, freeze encoder weights and perform linear probing
loss_fxn:str='CrossEntropy', # loss function to use, can be CrossEntropy or FocalLoss
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
gamma:float=2.0, # for focal loss
label_smoothing:int=0, # label smoothing for cross entropy loss
y_padding_mask:int=-100, # padding value added to the targets; this index is ignored when computing the loss
epochs:int=100, # number of epochs for one_cycle_scheduler
weight_decay:float=0.0, # weight decay for Adam optimizer
use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
optimizer_type:str='Adam', # optimizer to use, 'Adam' or 'AdamW'
scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
scheduler_kwargs:dict={}, # kwargs for the scheduler
):
Hooks to be used in LightningModule.
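A minimal usage sketch for PatchTFTSleepStage. The head dimensions, metric entry, encoder placeholder, and trainer settings below are illustrative assumptions, not values prescribed by the module.

import pytorch_lightning as pl
import torch.nn as nn
from functools import partial
from torchmetrics.functional import accuracy

pretrained_encoder = ...   # hypothetical: PatchTFT encoder loaded from a pretraining checkpoint
head = nn.Linear(512, 5)   # assumes an encoder dimension of 512 and 5 sleep stages

module = PatchTFTSleepStage(
    learning_rate=1e-3,
    train_size=10_000,      # number of training examples
    batch_size=64,
    n_gpus=1,
    linear_probing_head=head,
    preloaded_model=pretrained_encoder,
    metrics={'accuracy': partial(accuracy, task='multiclass', num_classes=5)},  # illustrative metric
    fine_tune=False,        # freeze the encoder and train only the head
    loss_fxn='CrossEntropy',
    epochs=20,
)
trainer = pl.Trainer(max_epochs=20, devices=1)
# trainer.fit(module, train_dataloader, val_dataloader)  # hypothetical dataloaders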
PatchTFT Single Outcome Prediction
PatchTFTSingleOutcomeLightning
def PatchTFTSingleOutcomeLightning(
linear_probing_head, # model head to linear probe/train
learning_rate, # desired learning rate; used as the initial learning rate if one_cycle_scheduler=True
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
n_gpus, # number of GPUs
preloaded_model, # loaded pretrained model to use for linear probing
metrics:dict={}, # name:function for metrics to log
fine_tune:bool=False, # if True, fine-tune the encoder; if False, freeze encoder weights and perform linear probing
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
epochs:int=100, # number of epochs for one_cycle_scheduler
scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
weight_decay:float=0.0, # weight decay for Adam optimizer
use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
scheduler_kwargs:dict={}, # kwargs for the scheduler
transforms:NoneType=None, # transforms to apply to the data
mixup_callback:NoneType=None, # mixup callback to apply to the data
regression:bool=False, # whether the outcome is regression or classification
loss_func:NoneType=None, # loss function to use
):
Hooks to be used in LightningModule.
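A sketch of a single-outcome regression probe with PatchTFTSingleOutcomeLightning. The head size, MSE loss, and MAE metric are assumptions about how the outcome might be set up, not requirements of the module.

import torch.nn as nn
from torchmetrics.functional import mean_absolute_error

pretrained_encoder = ...             # hypothetical pretrained PatchTFT encoder
regression_head = nn.Linear(512, 1)  # assumes encoder dimension 512 and a scalar outcome

module = PatchTFTSingleOutcomeLightning(
    linear_probing_head=regression_head,
    learning_rate=1e-4,
    train_size=5_000,
    batch_size=32,
    n_gpus=1,
    preloaded_model=pretrained_encoder,
    metrics={'mae': mean_absolute_error},
    fine_tune=False,          # linear probing: encoder weights stay frozen
    regression=True,
    loss_func=nn.MSELoss(),   # assumes an nn.Module loss is accepted here
    epochs=50,
)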
PatchTFTSingleOutcomeHypnogramLightning
def PatchTFTSingleOutcomeHypnogramLightning(
linear_probing_head, # model head to linear probe/train
learning_rate, # desired learning rate; used as the initial learning rate if one_cycle_scheduler=True
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
n_gpus, # number of GPUs
preloaded_model, # loaded pretrained model to use for linear probing
patch_len, # patch length used by the encoder
d_model, # encoder embedding dimension
metrics:dict={}, # name:function for metrics to log
hypnogram_index:int=7, # channel index of the hypnogram in the input
signal_only:bool=False, # use only the signal channels (ignore the hypnogram)
zero_hypnogram:bool=False, # zero out the hypnogram channel
hypnogram_only:bool=False, # use only the hypnogram channel
fine_tune:bool=False, # if True, fine-tune the encoder; if False, freeze encoder weights and perform linear probing
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
epochs:int=100, # number of epochs for one_cycle_scheduler
scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
weight_decay:float=0.0, # weight decay for Adam optimizer
use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
scheduler_kwargs:dict={}, # kwargs for the scheduler
transforms:NoneType=None, # transforms to apply to the data
mixup_callback:NoneType=None, # mixup callback to apply to the data
):
Hooks to be used in LightningModule.
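The hypnogram-aware variant adds flags controlling how the hypnogram channel is used. A sketch follows; the patch length, model dimension, channel index, and head shape are illustrative assumptions.

import torch.nn as nn

pretrained_encoder = ...   # hypothetical pretrained PatchTFT encoder
head = nn.Linear(512, 1)   # assumed encoder dimension and scalar outcome

module = PatchTFTSingleOutcomeHypnogramLightning(
    linear_probing_head=head,
    learning_rate=1e-4,
    train_size=5_000,
    batch_size=32,
    n_gpus=1,
    preloaded_model=pretrained_encoder,
    patch_len=30,            # illustrative patch length
    d_model=512,             # illustrative model dimension
    hypnogram_index=7,       # hypnogram assumed to live in channel 7
    signal_only=False,       # keep both signal channels and hypnogram
    zero_hypnogram=False,
    hypnogram_only=False,
    fine_tune=False,
    epochs=50,
)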
Full Supervised Training
PatchTFT Sleep Stage Supervised Training
PatchTFTSupervised
def PatchTFTSupervised(
encoder_kwargs, # args to initialize the PatchTFT encoder
linear_probing_head, # model head to linear probe/train
learning_rate, # desired learning rate; used as the initial learning rate if one_cycle_scheduler=True
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
n_gpus, # number of GPUs
num_nodes:int=1, # number of nodes
metrics:dict={}, # name:function for metrics to log
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
epochs:int=100, # number of epochs for one_cycle_scheduler
scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
weight_decay:float=0.0, # weight decay for Adam optimizer
use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
scheduler_kwargs:dict={}, # kwargs for the scheduler
transforms:NoneType=None, # transforms to apply to the data
mixup_callback:NoneType=None, # mixup callback to apply to the data
):
Hooks to be used in LightningModule.
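For fully supervised training the encoder is built from encoder_kwargs rather than loaded from a checkpoint. A sketch of the call is below; the encoder_kwargs keys shown are assumptions about the PatchTFT encoder constructor and should be replaced with the actual arguments it expects.

import torch.nn as nn

# Assumed encoder constructor keys; adjust to the real PatchTFT signature.
encoder_kwargs = dict(c_in=8, patch_len=30, d_model=512, n_layers=4, n_heads=8)
head = nn.Linear(512, 5)   # assumes encoder dimension 512 and 5 sleep stages

module = PatchTFTSupervised(
    encoder_kwargs=encoder_kwargs,
    linear_probing_head=head,
    learning_rate=1e-3,
    train_size=10_000,
    batch_size=64,
    n_gpus=4,
    num_nodes=1,
    epochs=100,
)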
Miscellaneous
LinearProbingCallback
def LinearProbingCallback(
lp_model, # model to linear probe
lp_epochs, # number of epochs to train the probe for
lp_train_dataloader, # dataloader for probe training
lp_val_dataloader, # dataloader for probe validation
lp_class_weights, # class weights for the probe's CE loss
learning_rate, # learning rate for the probe optimizer
weight_decay:float=0.0, # weight decay for the probe optimizer
eval_frequency:int=5, # how often (in epochs) to run the probing evaluation
):
Abstract base class used to build new callbacks.
Subclass this class and override any of the relevant hooks.
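A sketch of attaching the callback to a pretraining run. Here probe_model, probe_train_dl, probe_val_dl, and pretraining_module are hypothetical objects standing in for the probe module, its dataloaders, and the encoder being pretrained.

import pytorch_lightning as pl

probe_model = ...       # hypothetical linear probing module
probe_train_dl = ...    # hypothetical probe training dataloader
probe_val_dl = ...      # hypothetical probe validation dataloader

lp_callback = LinearProbingCallback(
    lp_model=probe_model,
    lp_epochs=10,
    lp_train_dataloader=probe_train_dl,
    lp_val_dataloader=probe_val_dl,
    lp_class_weights=None,
    learning_rate=1e-3,
    weight_decay=0.0,
    eval_frequency=5,
)
trainer = pl.Trainer(max_epochs=100, callbacks=[lp_callback])
# trainer.fit(pretraining_module, pretrain_dataloader)  # hypothetical pretraining run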
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class EncoderTestFreezingWeights(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(750, 512)

    def forward(self, x):
        x = self.encoder(x)
        return x

class DecoderTest(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.freeze()
        self.decoder = nn.Linear(512, 750)

    def training_step(self, batch, batch_idx):
        print("Training step:")
        print(f"Encoder training mode: {self.encoder.training}")
        print(f"Decoder training mode: {self.decoder.training}")
        print(f"Model training mode: {self.training}")
        return torch.tensor(0.0, requires_grad=True)

    def validation_step(self, batch, batch_idx):
        print("Validation step:")
        print(f"Encoder training mode: {self.encoder.training}")
        print(f"Decoder training mode: {self.decoder.training}")
        print(f"Model training mode: {self.training}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def train_dataloader(self):
        # Create a simple dummy dataset
        return DataLoader(TensorDataset(torch.randn(10, 750), torch.randn(10, 750)), batch_size=2)

    def val_dataloader(self):
        # Create a simple dummy dataset
        return DataLoader(TensorDataset(torch.randn(10, 750), torch.randn(10, 750)), batch_size=2)

# Create a trainer and run a quick sanity-check fit (one training and one validation batch)
encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(decoder)