Train
Some Specific Trainers
PatchTFTSleepStage
def PatchTFTSleepStage(
learning_rate, train_size, batch_size, n_gpus, linear_probing_head, preloaded_model, metrics:dict={},
fine_tune:bool=False, loss_fxn:str='CrossEntropy', class_weights:NoneType=None, gamma:float=2.0,
label_smoothing:int=0, y_padding_mask:int=-100, epochs:int=100, weight_decay:float=0.0,
use_weight_decay_scheduler:bool=False, final_weight_decay:float=0.01, optimizer_type:str='Adam',
scheduler_type:str='OneCycle', scheduler_kwargs:dict={}
):
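Hooks to be used in LightningModule. (Note: this is the generic docstring inherited from PyTorch Lightning; no description specific to `PatchTFTSleepStage` has been written yet.)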
PatchTFTSingleOutcomeLightning
def PatchTFTSingleOutcomeLightning(
linear_probing_head, learning_rate, train_size, batch_size, n_gpus, preloaded_model, metrics:dict={},
fine_tune:bool=False, class_weights:NoneType=None, epochs:int=100, scheduler_type:str='OneCycle',
optimizer_type:str='AdamW', weight_decay:float=0.0, use_weight_decay_scheduler:bool=False,
final_weight_decay:float=0.01, scheduler_kwargs:dict={}, transforms:NoneType=None, mixup_callback:NoneType=None,
regression:bool=False, loss_func:NoneType=None
):
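Hooks to be used in LightningModule. (Note: this is the generic docstring inherited from PyTorch Lightning; no description specific to `PatchTFTSingleOutcomeLightning` has been written yet.)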