Train

GPUs pulling all the weight

Supervised Training with Linear Probing

PatchTFT Masked Autoregression Sleep Stage Prediction


source

PatchTFTSleepStage


def PatchTFTSleepStage(
    learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler=True
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of gpus to use
    linear_probing_head, # model head to train
    preloaded_model, # loaded pretrained model to use for linear probing
    metrics:dict={}, # metrics to calculate
    fine_tune:bool=False, # indicator to finetune encoder model or perform linear probing and freeze encoder weights
    loss_fxn:str='CrossEntropy', # loss function to use, can be CrossEntropy or FocalLoss
    class_weights:NoneType=None, # weights of classes to use in CE loss fxn
    gamma:float=2.0, # for focal loss
    label_smoothing:int=0, # label smoothing for cross entropy loss
    y_padding_mask:int=-100, # padding value added to the targets; entries equal to it are ignored when computing the loss
    epochs:int=100, # number of epochs for one_cycle_scheduler
    weight_decay:float=0.0, # weight decay for Adam optimizer
    use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
    final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
    optimizer_type:str='Adam', # optimizer to use, 'Adam' or 'AdamW'
    scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
    scheduler_kwargs:dict={}, # kwargs for the scheduler
):

Hooks to be used in LightningModule.
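
A minimal usage sketch for linear probing, assuming a pretrained encoder checkpoint and a classification head of your own; the head, checkpoint path, and dimensions below are illustrative placeholders, not part of the library:

import torch
import torch.nn as nn

head = nn.Linear(512, 5)  # hypothetical head: 512-d embeddings -> 5 sleep stages
pretrained = torch.load('patchtft_pretrained.pt')  # hypothetical checkpoint path

module = PatchTFTSleepStage(
    learning_rate=1e-3,
    train_size=10_000,
    batch_size=64,
    n_gpus=1,
    linear_probing_head=head,
    preloaded_model=pretrained,
    fine_tune=False,       # freeze encoder weights; train only the head
    loss_fxn='FocalLoss',  # or 'CrossEntropy'
    gamma=2.0,
)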

PatchTFT Single Outcome Prediction


source

PatchTFTSingleOutcomeLightning


def PatchTFTSingleOutcomeLightning(
    linear_probing_head, # model head to linear probe/train
    learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler=True
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of GPUs
    preloaded_model, # loaded pretrained model to use for linear probing
    metrics:dict={}, # name:function for metrics to log
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    class_weights:NoneType=None, # weights of classes to use in CE loss fxn
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
    optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
    weight_decay:float=0.0, # weight decay for Adam optimizer
    use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
    final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
    scheduler_kwargs:dict={}, # kwargs for the scheduler
    transforms:NoneType=None, # transforms to apply to the data
    mixup_callback:NoneType=None, # mixup callback to apply to the data
    regression:bool=False, # whether the outcome is regression or classification
    loss_func:NoneType=None, # loss function to use
):

Hooks to be used in LightningModule.
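
The same pattern applies here; for a regression outcome, a hedged sketch might look like the following (the head, checkpoint path, and loss choice are placeholders):

import torch
import torch.nn as nn

head = nn.Linear(512, 1)  # hypothetical regression head
pretrained = torch.load('patchtft_pretrained.pt')  # hypothetical checkpoint path

module = PatchTFTSingleOutcomeLightning(
    linear_probing_head=head,
    learning_rate=1e-3,
    train_size=10_000,
    batch_size=64,
    n_gpus=1,
    preloaded_model=pretrained,
    regression=True,         # continuous outcome rather than classification
    loss_func=nn.MSELoss(),  # custom loss passed explicitly
)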


source

PatchTFTSingleOutcomeHypnogramLightning


def PatchTFTSingleOutcomeHypnogramLightning(
    linear_probing_head, # model head to linear probe/train
    learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler=True
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of GPUs
    preloaded_model, # loaded pretrained model to use for linear probing
    patch_len, # patch length used by the encoder
    d_model, # encoder embedding dimension
    metrics:dict={}, # name:function for metrics to log
    hypnogram_index:int=7, # channel index of the hypnogram in the input
    signal_only:bool=False, # use only the signal channels, dropping the hypnogram
    zero_hypnogram:bool=False, # zero out the hypnogram channel
    hypnogram_only:bool=False, # use only the hypnogram channel
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    class_weights:NoneType=None, # weights of classes to use in CE loss fxn
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
    optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
    weight_decay:float=0.0, # weight decay for Adam optimizer
    use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
    final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
    scheduler_kwargs:dict={}, # kwargs for the scheduler
    transforms:NoneType=None, # transforms to apply to the data
    mixup_callback:NoneType=None, # mixup callback to apply to the data
):

Hooks to be used in LightningModule.

By default the hypnogram occupies channel hypnogram_index = 7, with the signal channels before it:

import torch

hypnogram_index = 7  # channel holding the hypnogram
x_index = torch.tensor(list(range(hypnogram_index)))  # signal channels 0..6
hypnogram_index = torch.tensor([hypnogram_index])  # 1-element tensor for index_select
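
For intuition, a sketch of how these index tensors could split a batch into signals and hypnogram; the (batch, channels, time) layout and the index_select call are illustrative assumptions about the data, not the module's exact internals:

batch = torch.randn(4, 8, 750)                      # assumed (batch, channels, time)
signals = batch.index_select(1, x_index)            # channels 0..6 -> shape (4, 7, 750)
hypnogram = batch.index_select(1, hypnogram_index)  # channel 7    -> shape (4, 1, 750)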

Full Supervised Training

PatchTFT Sleep Stage Supervised Training


source

PatchTFTSupervised


def PatchTFTSupervised(
    encoder_kwargs, # args to initialize the PatchTFT encoder
    linear_probing_head, # model head to linear probe/train
    learning_rate, # desired learning rate; the initial learning rate if one_cycle_scheduler=True
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of GPUs
    num_nodes:int=1, # number of nodes
    metrics:dict={}, # name:function for metrics to log
    class_weights:NoneType=None, # weights of classes to use in CE loss fxn
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', # scheduler to use, 'OneCycle' or 'CosineAnnealingWarmRestarts'
    optimizer_type:str='AdamW', # optimizer to use, 'Adam' or 'AdamW'
    weight_decay:float=0.0, # weight decay for Adam optimizer
    use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
    final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
    scheduler_kwargs:dict={}, # kwargs for the scheduler
    transforms:NoneType=None, # transforms to apply to the data
    mixup_callback:NoneType=None, # mixup callback to apply to the data
):

Hooks to be used in LightningModule.
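
Unlike the probing modules above, this one builds the encoder from scratch via encoder_kwargs. A hedged sketch; the kwargs keys and head here are illustrative guesses, not the PatchTFT encoder's documented constructor:

import torch.nn as nn

module = PatchTFTSupervised(
    encoder_kwargs={'patch_len': 30, 'd_model': 512},  # hypothetical encoder args
    linear_probing_head=nn.Linear(512, 5),             # hypothetical head
    learning_rate=1e-3,
    train_size=10_000,
    batch_size=64,
    n_gpus=1,
    num_nodes=1,
)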

Miscellaneous


source

LinearProbingCallback


def LinearProbingCallback(
    lp_model, # model head to linear probe
    lp_epochs, # number of epochs to train the probe for
    lp_train_dataloader, # dataloader for probe training
    lp_val_dataloader, # dataloader for probe validation
    lp_class_weights, # class weights for the probe loss
    learning_rate, # learning rate for the probe optimizer
    weight_decay:float=0.0, # weight decay for the probe optimizer
    eval_frequency:int=5 # run the probe every eval_frequency epochs
):

Abstract base class used to build new callbacks.

Subclass this class and override any of the relevant hooks.
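
The snippet below is a sanity check for weight freezing: LightningModule.freeze() sets requires_grad=False on the encoder's parameters and puts it in eval mode, and the prints show which training mode each submodule is actually in once the trainer starts toggling train/eval on the parent module.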

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class EncoderTestFreezingWeights(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(750, 512)
    def forward(self, x):
        x = self.encoder(x)
        return x

class DecoderTest(pl.LightningModule):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        self.encoder.freeze()  # sets requires_grad=False on encoder params and calls .eval()
        self.decoder = nn.Linear(512, 750)
        
    def training_step(self, batch, batch_idx):
        print("Training step:")
        print(f"Encoder training mode: {self.encoder.training}")
        print(f"Decoder training mode: {self.decoder.training}")
        print(f"Model training mode: {self.training}")
        return torch.tensor(0.0, requires_grad=True)  # dummy loss so the optimizer step runs
        
    def validation_step(self, batch, batch_idx):
        print("Validation step:")
        print(f"Encoder training mode: {self.encoder.training}")
        print(f"Decoder training mode: {self.decoder.training}")
        print(f"Model training mode: {self.training}")
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)
        
    def train_dataloader(self):
        # Create a simple dummy dataset
        return DataLoader(TensorDataset(torch.randn(10, 750), torch.randn(10, 750)), batch_size=2)
        
    def val_dataloader(self):
        # Create a simple dummy dataset
        return DataLoader(TensorDataset(torch.randn(10, 750), torch.randn(10, 750)), batch_size=2)

# Run a single train and validation batch (fast_dev_run) to inspect the training modes
encoder = EncoderTestFreezingWeights()
decoder = DecoderTest(encoder)
trainer = pl.Trainer(fast_dev_run=True)
trainer.fit(decoder)