Train
GPUs pulling all the weight
PatchTFT Single Outcome Prediction
PatchTFTSingleOutcomeLightning
def PatchTFTSingleOutcomeLightning(
linear_probing_head, # model head to linear probe/train
learning_rate, # desired learning rate; used as the initial learning rate if one_cycle_scheduler
train_size, # the training data size (for one_cycle_scheduler=True)
batch_size, # the batch size (for one_cycle_scheduler=True)
n_gpus, # number of GPUs
preloaded_model, # loaded pretrained model to use for linear probing
metrics:dict={}, # name:function for metrics to log
fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
class_weights:NoneType=None, # weights of classes to use in CE loss fxn
epochs:int=100, # number of epochs for one_cycle_scheduler
scheduler_type:str='OneCycle', optimizer_type:str='AdamW',
weight_decay:float=0.0, # weight decay for Adam optimizer
use_weight_decay_scheduler:bool=False, # use a weight decay scheduler
final_weight_decay:float=0.01, # final weight decay for the weight decay scheduler
scheduler_kwargs:dict={}, # kwargs for the scheduler
transforms:NoneType=None, # transforms to apply to the data
mixup_callback:NoneType=None, # mixup callback to apply to the data
):
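Trains a prediction head (`linear_probing_head`) on top of a loaded pretrained model (`preloaded_model`), either by linear probing with the encoder weights frozen (default) or by fine-tuning the encoder as well (`fine_tune=True`). The generic line "Hooks to be used in LightningModule." appears to be inherited from PyTorch Lightning's base-class docstring rather than describing this model.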