Train

GPUs pulling all the weight

PatchTFT Single Outcome Prediction


source

PatchTFTSingleOutcomeLightning


def PatchTFTSingleOutcomeLightning(
    linear_probing_head, # model head to linear probe/train
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of GPUs
    preloaded_model, # loaded pretrained model to use for linear probing
    metrics:dict={}, # name:function mapping of metrics to log
    fine_tune:bool=False, # if True, fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    class_weights:NoneType=None, # class weights to use in the CE loss function
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', optimizer_type:str='AdamW',
    weight_decay:float=0.0, # weight decay for the AdamW optimizer
    use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
    final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
    scheduler_kwargs:dict={}, # kwargs for the scheduler
    transforms:NoneType=None, # transforms to apply to the data
    mixup_callback:NoneType=None, # mixup callback to apply to the data
):
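When `scheduler_type='OneCycle'`, the scheduler needs to know the total number of optimizer steps in advance, which is why `train_size`, `batch_size`, `n_gpus`, and `epochs` are all required. A minimal sketch of that computation, assuming the effective batch size under multi-GPU data-parallel training is `batch_size * n_gpus` (an assumption; the exact formula used internally is not shown here):

```python
import math

def one_cycle_total_steps(train_size, batch_size, n_gpus, epochs):
    # Each epoch takes ceil(train_size / effective_batch_size) optimizer steps,
    # where the effective batch size multiplies across data-parallel GPUs.
    effective_batch_size = batch_size * n_gpus
    steps_per_epoch = math.ceil(train_size / effective_batch_size)
    return steps_per_epoch * epochs

print(one_cycle_total_steps(train_size=10_000, batch_size=32, n_gpus=2, epochs=100))
# prints 15700  (ceil(10000 / 64) = 157 steps/epoch, times 100 epochs)
```

This total would typically be passed to the scheduler (e.g. as `total_steps` for PyTorch's `OneCycleLR`) so the learning rate ramps up from its initial value and anneals back down over exactly one training run.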

Hooks to be used in LightningModule.
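With `use_weight_decay_scheduler=True`, the weight decay is ramped from `weight_decay` to `final_weight_decay` over training. A minimal sketch, assuming a linear ramp across epochs (an assumption; the schedule shape actually used is not specified here):

```python
def weight_decay_at_epoch(epoch, epochs=100, weight_decay=0.0, final_weight_decay=0.01):
    # Hypothetical linear schedule: interpolate from the initial weight_decay
    # at epoch 0 to final_weight_decay at the last epoch.
    frac = epoch / max(epochs - 1, 1)
    return weight_decay + frac * (final_weight_decay - weight_decay)
```

In a LightningModule this would typically be applied from an epoch hook by writing the new value into each optimizer param group's `'weight_decay'` entry.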