Survival

I’m barely alive

source

hazards_to_survival_torch


def hazards_to_survival_torch(
    preds
):

Convert per-bin hazard predictions into survival probabilities (PyTorch tensor version).
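The body isn't reproduced here; as a minimal sketch (assuming `preds` holds per-bin hazard probabilities of shape [N, J], not logits), survival is the running product of (1 - hazard) across bins:

import torch

def hazards_to_survival_torch_sketch(preds):
    # preds: [N, J] per-bin hazards h_j in [0, 1]
    # S_j = prod_{k <= j} (1 - h_k)
    return torch.cumprod(1.0 - preds, dim=-1)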


source

hazards_to_survival


def hazards_to_survival(
    preds
):

Convert per-bin hazard predictions into survival probabilities (NumPy array version).
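Presumably the NumPy counterpart of the function above; the same sketch under the same assumption:

import numpy as np

def hazards_to_survival_sketch(preds):
    # preds: [N, J] per-bin hazards; S_j = prod_{k <= j} (1 - h_k)
    return np.cumprod(1.0 - preds, axis=-1)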


source

discretize_time_convert_logits_to_estimates


def discretize_time_convert_logits_to_estimates(
    preds, new_time_days, cuts_days
):

Convert per-bin hazard logits into survival/risk estimates at the requested times new_time_days, given the bin right-edges cuts_days.
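A hedged sketch of what the name implies; the helper name below is hypothetical, and it assumes preds are per-bin logits and cuts_days comes from make_cuts_days:

import torch

def logits_to_estimates_sketch(preds, new_time_days, cuts_days):
    hazards = torch.sigmoid(preds)                           # [N, J] logits -> per-bin hazards
    survival = torch.cumprod(1.0 - hazards, dim=-1)          # [N, J] survival at each bin edge
    cuts = torch.as_tensor(cuts_days, dtype=torch.float32)   # [J] right edges in days
    times = torch.as_tensor(new_time_days, dtype=torch.float32).reshape(-1)
    j = torch.searchsorted(cuts, times).clamp(max=len(cuts) - 1)  # bin index per requested time
    return survival[:, j]                                    # [N, len(new_time_days)]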


source

make_cuts_days


def make_cuts_days(
    min_year:float, max_year:float, bin_years:float
):

Build right-edges of bins (in DAYS) for the discrete-time hazard model.

Returns: cuts_days (np.ndarray): shape [J], right edges in days.
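A minimal sketch consistent with the worked example further down (where make_cuts_days(0, 10, 1) prints edges at multiples of 365.25 days); the exact implementation may differ:

import numpy as np

def make_cuts_days_sketch(min_year, max_year, bin_years):
    # Right edges at min_year + bin_years, ..., max_year (in years),
    # converted to days with 365.25 days per year.
    edges_years = np.arange(min_year + bin_years, max_year + bin_years / 2, bin_years)
    return edges_years * 365.25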


source

discretize_time


def discretize_time(
    time_days, bin_years:float=0.25, min_year:float=0.0, max_year:float=10.0
):

Discretize a tensor of observed times (in days) into bin indices for discrete-time survival.

Args:

- time_days (Tensor): observed times in days (event or censoring).
- bin_years (float): width of bins in years (e.g., 0.25 = quarter-year).
- min_year (float): lower bound of the window in years.
- max_year (float): upper bound of the window in years.

Returns: bin_idx (Tensor): indices in [0, n_bins-1] for each input time

# Yearly bins from 0–10y
cuts_days = make_cuts_days(0, 10, 1)
# edges_years = [1,2,3,...,10]
# bins = (0,1], (1,2], ..., (9,10]

# Times
t = torch.tensor([100., 365., 3700.])  # ~0.27y, 1y, ~10.1y
idx = discretize_time(t, bin_years=1, min_year=0, max_year=10)
# -> [0, 0, 9]  (last one clipped at 10y)
print(cuts_days), print(idx)
[ 365.25  730.5  1095.75 1461.   1826.25 2191.5  2556.75 2922.   3287.25
 3652.5 ]
tensor([0, 0, 9])
(None, None)
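The implementation isn't shown here, but a sketch that reproduces the output above (right-closed bins, with times past max_year clipped into the last bin) could look like:

import torch

def discretize_time_sketch(time_days, bin_years=0.25, min_year=0.0, max_year=10.0):
    cuts = torch.as_tensor(make_cuts_days(min_year, max_year, bin_years), dtype=torch.float32)
    # right-closed bins: t in (edge_{j-1}, edge_j] -> index j
    idx = torch.bucketize(torch.as_tensor(time_days, dtype=torch.float32), cuts)
    return idx.clamp(max=len(cuts) - 1)   # clip times beyond max_year into the last bin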

Cox with Momentum


source

PatchTFTSurvivalDemo


def PatchTFTSurvivalDemo(
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of gpus to use
    linear_probing_head, # model head to linear probe/train
    preloaded_model, # loaded pretrained model to use for linear probing
    evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=True, # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25, # function to discretize time for discrete-time survival
    time_range_years:tuple=(0, 15), # events after this time are censored
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', optimizer_type:str='AdamW',
    weight_decay:float=0.001, # weight decay for Adam optimizer
    final_weight_decay:float=1e-06, # weight decay for final layer
    use_weight_decay_scheduler:bool=False,
    demographic_embeddings:dict={'age_bin_idx': 9, 'bmi_bin_idx': 7, 'gender': 2}, age_mlp_hidden_size:int=32,
    demographics_only:bool=False, scheduler_kwargs:dict={}, transforms:NoneType=None
):

LightningModule for linear probing or fine-tuning a preloaded PatchTFT encoder with a survival head, optionally with discrete-time bins and demographic embeddings (age, BMI, gender).
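A hedged usage sketch using only the parameters documented above; `encoder` and `head` stand in for a pretrained PatchTFT model and a probing head built elsewhere in the library:

model = PatchTFTSurvivalDemo(
    learning_rate=1e-3,
    train_size=50_000,
    batch_size=256,
    n_gpus=1,
    linear_probing_head=head,
    preloaded_model=encoder,
    evaluate_risk_years=[1, 5, 10],
    discretize_time=True,
    discrete_time_bins=0.25,     # quarter-year bins
    time_range_years=(0, 15),
    fine_tune=False,             # freeze the encoder; train only the head
    epochs=100,
)
# trainer = pl.Trainer(devices=1); trainer.fit(model, datamodule)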


source

SurvivalDemo


def SurvivalDemo(
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of gpus to use
    evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=True, # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25, # bin width in years for discrete-time survival
    time_range_years:tuple=(0, 15), # events after this time are censored
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', optimizer_type:str='AdamW',
    weight_decay:float=0.001, # weight decay for Adam optimizer
    final_weight_decay:float=1e-06, # weight decay for final layer
    use_weight_decay_scheduler:bool=False,
    demographic_embeddings:dict={'age_bin_idx': 9, 'bmi_bin_idx': 7, 'gender': 2}, age_mlp_hidden_size:int=32,
    scheduler_kwargs:dict={}, transforms:NoneType=None
):

LightningModule for survival training with demographic embeddings (age, BMI, gender) but without a preloaded PatchTFT encoder.


source

PatchTFTSurvivalMomentum


def PatchTFTSurvivalMomentum(
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of gpus to use
    linear_probing_head, # model head to linear probe/train
    preloaded_model, # loaded pretrained model to use for linear probing
    loss_func, evaluate_risk_years:list=[1, 5, 10],
    discretize_time:bool=False, # indicator to discretize time for discrete-time survival
    discrete_time_bins:float=0.25, # bin width in years for discrete-time survival
    time_range_years:tuple=(0, 15), # events after this time are censored
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', momentum_steps:int=4, momentum_rate:float=0.999, optimizer_type:str='AdamW',
    weight_decay:float=0.001, # weight decay for Adam optimizer
    final_weight_decay:float=1e-06, # weight decay for final layer
    use_weight_decay_scheduler:bool=False, demographic_predictor:bool=False, n_demographics:int=1,
    demographic_predictor_hidden:int=64, demographic_adversarial_weight:float=1.0, scheduler_kwargs:dict={}
):

LightningModule for survival training of a preloaded PatchTFT encoder with a momentum-updated target encoder; the survival loss (e.g. Cox) is supplied via loss_func.
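momentum_steps and momentum_rate presumably control an exponential-moving-average (EMA) target encoder, and discretize_time now defaults to False, in line with the Cox-style setup this section is named after. A rough sketch of the update such parameters usually imply (target_encoder and online_encoder are hypothetical names, not attributes documented here):

import torch

with torch.no_grad():
    # every momentum_steps optimizer steps, keep momentum_rate of the target
    # weights and mix in (1 - momentum_rate) of the online weights
    for p_t, p_o in zip(target_encoder.parameters(), online_encoder.parameters()):
        p_t.mul_(momentum_rate).add_(p_o, alpha=1.0 - momentum_rate)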


source

PatchTFTHypnogramSurvivalMomentum


def PatchTFTHypnogramSurvivalMomentum(
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, linear_probing_head, # model head to linear probe/train
    preloaded_model, # loaded pretrained model to use for linear probing
    loss_func, d_model, patch_len, hypnogram_index:int=7, hypnogram_only:bool=False, signal_only:bool=False,
    zero_hypnogram:bool=False, auc_years:list=[1, 5, 10],
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', momentum_steps:int=4, momentum_rate:float=0.999, optimizer_type:str='AdamW',
    weight_decay:float=0.001, # weight decay for Adam optimizer
    final_weight_decay:float=1e-06, # weight decay for final layer
    use_weight_decay_scheduler:bool=False,
    create_zero_channel_mask:bool=False, # create a zero channel mask for the encoder when a data channel is all zeros
    scheduler_kwargs:dict={}
):

LightningModule for momentum-encoder survival training that additionally encodes a hypnogram channel (see HypnogramEncoder below); hypnogram_only and signal_only toggle which inputs are used.


source

HypnogramEncoder


def HypnogramEncoder(
    patch_len, d_model, num_categories:int=5, y_padding_mask:int=5
):

An nn.Module that encodes a categorical hypnogram (sleep-stage) sequence into per-patch features: stages are embedded, grouped into patches of length patch_len, and projected to d_model dimensions (see the shape trace in the example below). Values equal to y_padding_mask are treated as padding, and num_categories is the number of distinct stage labels.

# Random hypnogram batch: 5 subjects, 101 sleep-stage epochs, stages in {0, ..., 4}
hyp = torch.randint(0, 5, (5, 101))
# hyp[2, 10:30] = 5  # optionally mark a span as padding (y_padding_mask)

hypnogram_encoder = HypnogramEncoder(patch_len=5, num_categories=6, d_model=512, y_padding_mask=5)
hypnogram_enc = hypnogram_encoder(hyp)
torch.Size([5, 1, 101])
torch.Size([5, 21, 1, 5])
torch.Size([5, 21, 1, 5, 128])
torch.Size([5, 21, 1, 128])
torch.Size([5, 21, 1, 512])
torch.Size([5, 1, 512, 21])
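Reading the printed shapes, the encoder appears to add a channel dimension ([5, 1, 101]), split the 101 epochs into 21 patches of length 5 ([5, 21, 1, 5], the last patch presumably padded), embed each stage ([5, 21, 1, 5, 128]), aggregate the embeddings within each patch ([5, 21, 1, 128]), project to d_model=512 ([5, 21, 1, 512]), and finally permute to [batch, channels, d_model, n_patches].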

Weibull


source

PatchTFTSurvivalWeibull


def PatchTFTSurvivalWeibull(
    learning_rate, # desired learning rate; initial learning rate if one_cycle_scheduler is used
    train_size, # the training data size (for one_cycle_scheduler=True)
    batch_size, # the batch size (for one_cycle_scheduler=True)
    n_gpus, # number of gpus
    linear_probing_head, # model head to linear probe/train
    preloaded_model, # loaded pretrained model to use for linear probing
    loss_func, auc_years:list=[1, 5, 10],
    fine_tune:bool=False, # indicator to fine tune encoder or freeze encoder weights and perform linear probing
    max_lr:float=0.01, # maximum learning rate for one_cycle_scheduler
    epochs:int=100, # number of epochs for one_cycle_scheduler
    scheduler_type:str='OneCycle', optimizer_type:str='Adam',
    weight_decay:float=0.0, # weight decay for Adam optimizer
    torch_model_name:str='model', # name of the pytorch model within the lightning module; used when removing layers (e.g. lightning_model.pytorch_model.head = nn.Identity())
    remove_pretrain_layers:list=['head'], # layers within the lightning module or lightning_model.pytorch_model to remove
    create_zero_channel_mask:bool=False, # create a zero channel mask for the encoder when a data channel is all zeros
):

LightningModule for Weibull survival modeling on top of a preloaded PatchTFT encoder; the survival loss is supplied via loss_func.
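For reference, the standard Weibull survival function this head presumably targets; how the module parameterizes scale and shape per subject isn't shown here, so treat this as a sketch:

import torch

def weibull_survival_sketch(t_days, scale, shape):
    # S(t) = exp(-(t / scale) ** shape), with scale > 0 (characteristic time) and shape > 0
    return torch.exp(-((t_days / scale) ** shape))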