Discretize a tensor of observed times (in days) into bin indices for discrete-time survival.

Args:
    time_days (Tensor): observed times in days (event or censoring).
    bin_years (float): width of bins in years (e.g., 0.25 = a quarter year).
    min_year (float): lower bound of the time window in years.
    max_year (float): upper bound of the time window in years.

Returns:
    bin_idx (Tensor): indices in [0, n_bins - 1] for each input time.
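A minimal sketch of how this discretization could be implemented in PyTorch, assuming a 365.25 days-per-year conversion and clamping into the window; the function name and body are illustrative, not the library's confirmed implementation:

import torch

def discretize_time_days(time_days, bin_years=0.25, min_year=0.0, max_year=15.0):
    # Convert observed days to years (365.25 days/year assumed).
    time_years = time_days / 365.25
    # Total number of bins covering [min_year, max_year).
    n_bins = int(round((max_year - min_year) / bin_years))
    # Floor-divide into bins, then clamp indices into [0, n_bins - 1].
    bin_idx = torch.floor((time_years - min_year) / bin_years).long()
    return bin_idx.clamp(0, n_bins - 1)

For example, with the defaults, a time of 2000 days (~5.48 years) falls into bin 21.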
def PatchTFTSurvivalDemo(
    learning_rate,  # desired learning rate; initial learning rate if using the one-cycle scheduler
    train_size,  # the training data size (for one_cycle_scheduler=True)
    batch_size,  # the batch size (for one_cycle_scheduler=True)
    n_gpus,  # number of GPUs to use
    linear_probing_head,  # model head to linear probe/train
    preloaded_model,  # loaded pretrained model to use for linear probing
    evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=True,  # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25,  # bin width in years for discrete-time survival
    time_range_years:tuple=(0, 15),  # events after this time are censored
    fine_tune:bool=False,  # fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    epochs:int=100,  # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle',
    optimizer_type:str='AdamW',
    weight_decay:float=0.001,  # weight decay for the optimizer
    final_weight_decay:float=1e-06,  # weight decay for the final layer
    use_weight_decay_scheduler:bool=False,
    demographic_embeddings:dict={'age_bin_idx': 9, 'bmi_bin_idx': 7, 'gender': 2},
    age_mlp_hidden_size:int=32,
    demographics_only:bool=False,
    scheduler_kwargs:dict={},
    transforms=None,  # optional input transforms
):
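A hypothetical instantiation using the defaults above; `probe_head` and `pretrained` are placeholder names for a head module and a loaded pretrained encoder prepared elsewhere:

module = PatchTFTSurvivalDemo(
    learning_rate=1e-3,
    train_size=50_000,
    batch_size=64,
    n_gpus=1,
    linear_probing_head=probe_head,  # placeholder: head module to train
    preloaded_model=pretrained,      # placeholder: loaded pretrained encoder
    evaluate_risk_years=[1, 5, 10],
    fine_tune=False,                 # freeze the encoder; linear probe only
)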
def SurvivalDemo(
    learning_rate,  # desired learning rate; initial learning rate if using the one-cycle scheduler
    train_size,  # the training data size (for one_cycle_scheduler=True)
    batch_size,  # the batch size (for one_cycle_scheduler=True)
    n_gpus,  # number of GPUs to use
    evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=True,  # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25,  # bin width in years for discrete-time survival
    time_range_years:tuple=(0, 15),  # events after this time are censored
    fine_tune:bool=False,  # fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    epochs:int=100,  # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle',
    optimizer_type:str='AdamW',
    weight_decay:float=0.001,  # weight decay for the optimizer
    final_weight_decay:float=1e-06,  # weight decay for the final layer
    use_weight_decay_scheduler:bool=False,
    demographic_embeddings:dict={'age_bin_idx': 9, 'bmi_bin_idx': 7, 'gender': 2},
    age_mlp_hidden_size:int=32,
    scheduler_kwargs:dict={},
    transforms=None,  # optional input transforms
):
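The `demographic_embeddings` dict appears to map each categorical demographic feature to its number of categories. A minimal sketch of one way those sizes could back per-feature embedding tables; the `embed_dim` choice and `nn.ModuleDict` layout are assumptions, not the module's confirmed internals:

import torch.nn as nn

demographic_embeddings = {'age_bin_idx': 9, 'bmi_bin_idx': 7, 'gender': 2}
embed_dim = 32  # assumed embedding width for illustration
embeds = nn.ModuleDict({
    name: nn.Embedding(num_embeddings=n_categories, embedding_dim=embed_dim)
    for name, n_categories in demographic_embeddings.items()
})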
def PatchTFTSurvivalMomentum(
    learning_rate,  # desired learning rate; initial learning rate if using the one-cycle scheduler
    train_size,  # the training data size (for one_cycle_scheduler=True)
    batch_size,  # the batch size (for one_cycle_scheduler=True)
    n_gpus,  # number of GPUs to use
    linear_probing_head,  # model head to linear probe/train
    preloaded_model,  # loaded pretrained model to use for linear probing
    loss_func,  # loss function to use
    evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=False,  # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25,  # bin width in years for discrete-time survival
    time_range_years:tuple=(0, 15),  # events after this time are censored
    fine_tune:bool=False,  # fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    epochs:int=100,  # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle',
    momentum_steps:int=4,
    momentum_rate:float=0.999,
    optimizer_type:str='AdamW',
    weight_decay:float=0.001,  # weight decay for the optimizer
    final_weight_decay:float=1e-06,  # weight decay for the final layer
    use_weight_decay_scheduler:bool=False,
    demographic_predictor:bool=False,
    n_demographics:int=1,
    demographic_predictor_hidden:int=64,
    demographic_adversarial_weight:float=1.0,
    scheduler_kwargs:dict={},
):
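`momentum_steps` and `momentum_rate` suggest a momentum (EMA) target encoder updated every few optimizer steps. A generic sketch of such an update, assuming the textbook exponential-moving-average pattern rather than this module's confirmed code:

import torch

@torch.no_grad()
def momentum_update(online_encoder, target_encoder, momentum_rate=0.999):
    # EMA update, applied parameter-wise: target <- m * target + (1 - m) * online.
    for p_online, p_target in zip(online_encoder.parameters(),
                                  target_encoder.parameters()):
        p_target.mul_(momentum_rate).add_(p_online, alpha=1.0 - momentum_rate)

# Per the signature above, this would run every momentum_steps (default 4) steps.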
def PatchTFTHypnogramSurvivalMomentum(
    learning_rate,  # desired learning rate; initial learning rate if using the one-cycle scheduler
    train_size,  # the training data size (for one_cycle_scheduler=True)
    batch_size,  # the batch size (for one_cycle_scheduler=True)
    n_gpus,  # number of GPUs to use
    linear_probing_head,  # model head to linear probe/train
    preloaded_model,  # loaded pretrained model to use for linear probing
    loss_func,  # loss function to use
    d_model,  # model embedding dimension
    patch_len,  # length of each input patch
    hypnogram_index:int=7,
    hypnogram_only:bool=False,
    signal_only:bool=False,
    zero_hypnogram:bool=False,
    auc_years:list=[1, 5, 10],
    fine_tune:bool=False,  # fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    epochs:int=100,  # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle',
    momentum_steps:int=4,
    momentum_rate:float=0.999,
    optimizer_type:str='AdamW',
    weight_decay:float=0.001,  # weight decay for the optimizer
    final_weight_decay:float=1e-06,  # weight decay for the final layer
    use_weight_decay_scheduler:bool=False,
    create_zero_channel_mask:bool=False,  # create a zero channel mask for the encoder when a data channel is all zeros
    scheduler_kwargs:dict={},
):
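The `hypnogram_only`, `signal_only`, and `zero_hypnogram` flags imply channel selection around `hypnogram_index`. A sketch of plausible masking logic, assuming input of shape (batch, channels, time) with the hypnogram at channel 7; these are assumptions, not confirmed internals:

import torch

def select_channels(x, hypnogram_index=7, hypnogram_only=False,
                    signal_only=False, zero_hypnogram=False):
    if hypnogram_only:
        # Keep only the hypnogram channel.
        return x[:, hypnogram_index:hypnogram_index + 1, :]
    if signal_only:
        # Drop the hypnogram channel, keep all raw signal channels.
        return torch.cat([x[:, :hypnogram_index], x[:, hypnogram_index + 1:]], dim=1)
    if zero_hypnogram:
        # Keep the channel in place but zero out its values.
        x = x.clone()
        x[:, hypnogram_index] = 0.0
    return x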
def PatchTFTSurvivalWeibull(
    learning_rate,  # desired learning rate; initial learning rate if using the one-cycle scheduler
    train_size,  # the training data size (for one_cycle_scheduler=True)
    batch_size,  # the batch size (for one_cycle_scheduler=True)
    n_gpus,  # number of GPUs to use
    linear_probing_head,  # model head to linear probe/train
    preloaded_model,  # loaded pretrained model to use for linear probing
    loss_func,  # loss function to use
    auc_years:list=[1, 5, 10],
    fine_tune:bool=False,  # fine-tune the encoder; otherwise freeze encoder weights and perform linear probing
    max_lr:float=0.01,  # maximum learning rate for one_cycle_scheduler
    epochs:int=100,  # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle',
    optimizer_type:str='Adam',
    weight_decay:float=0.0,  # weight decay for the Adam optimizer
    torch_model_name:str='model',  # name of the pytorch model within the lightning module, used to remove layers (e.g., lightning_model.pytorch_model.head = nn.Identity())
    remove_pretrain_layers:list=['head'],  # layers within the lightning model or lightning_model.pytorch_model to remove
    create_zero_channel_mask:bool=False,  # create a zero channel mask for the encoder when a data channel is all zeros
):
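The `torch_model_name` comment sketches how `remove_pretrain_layers` is meant to work: each named layer on the wrapped pytorch model gets replaced with `nn.Identity()` so pretrained features pass through to the new head. A minimal sketch of that replacement loop; the function and variable names are illustrative:

import torch.nn as nn

def strip_pretrain_layers(lightning_model, torch_model_name='model',
                          remove_pretrain_layers=('head',)):
    # Resolve the inner pytorch model (e.g., lightning_model.model), then swap
    # each listed layer for an identity so encoder outputs pass through unchanged.
    inner = getattr(lightning_model, torch_model_name)
    for layer_name in remove_pretrain_layers:
        setattr(inner, layer_name, nn.Identity())
    return lightning_model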