Adapted from https://github.com/havakv/pycox/blob/3eccdd7fd9844a060f50fdcc315659f33a2d2dc1/pycox/models/loss.py#L18. Negative log-likelihood of the discrete-time hazard-parametrized model LogisticHazard [1].
Arguments:
  phi {torch.tensor} -- Estimates in (-inf, inf), where hazard = sigmoid(phi).
  idx_durations {torch.tensor} -- Event times represented as indices.
  events {torch.tensor} -- Indicator of event (1.) or censoring (0.). Same length as 'idx_durations'.
  reduction {string} -- How to reduce the loss. 'none': no reduction. 'mean': mean of tensor. 'sum': sum of tensor.
Returns: torch.tensor – The negative log-likelihood.
References: [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. https://arxiv.org/pdf/1910.06724.pdf
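The pycox source linked above implements this NLL as a binary cross-entropy over the discretized time grid; below is a minimal sketch following that implementation, with the reduction handling inlined to keep it self-contained.

import torch
import torch.nn.functional as F

def nll_logistic_hazard(phi, idx_durations, events, reduction='mean'):
    # hazard = sigmoid(phi); the BCE target is all zeros except (possibly)
    # a 1 at the observed duration index, and the NLL accumulates the BCE
    # over all time bins up to and including that index.
    if events.dtype is torch.bool:
        events = events.float()
    events = events.view(-1, 1)
    idx_durations = idx_durations.view(-1, 1)  # int64 indices into the time grid
    y_bce = torch.zeros_like(phi).scatter(1, idx_durations, events)
    bce = F.binary_cross_entropy_with_logits(phi, y_bce, reduction='none')
    loss = bce.cumsum(1).gather(1, idx_durations).view(-1)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss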
def multilabel_cox_loss(x, event, time, n_labels, ties_method: str = 'efron', reduction: str = 'mean'):
Multilabel Cox loss function.
  x: [bs x n_labels] -- "log hazard"
  event: [bs x n_labels] -- 1 for event, 0 for censored
  time: [bs x n_labels] -- time of event or time of censoring
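The body of multilabel_cox_loss is not shown here; a plausible sketch, assuming torchsurv's neg_partial_log_likelihood as the per-label Cox partial likelihood (that per-label call is an assumption, not something the source confirms):

import torch
from torchsurv.loss.cox import neg_partial_log_likelihood

def multilabel_cox_loss(x, event, time, n_labels,
                        ties_method: str = 'efron', reduction: str = 'mean'):
    # One Cox partial-likelihood term per label, averaged across labels.
    losses = []
    for i in range(n_labels):
        losses.append(neg_partial_log_likelihood(
            x[:, i].unsqueeze(1),   # log hazard for label i, shape [bs, 1]
            event[:, i].bool(),     # torchsurv expects boolean event indicators
            time[:, i],
            ties_method=ties_method,
            reduction=reduction,
        ))
    return torch.stack(losses).mean()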
x = torch.randn(2, 14)
event = torch.randint(0, 2, (2, 14))
time = torch.randint(1, 300, (2, 14))
cox_ph_loss(x, event, time), multilabel_cox_loss(x, event, time, 14)
/opt/miniconda3/envs/timeflies/lib/python3.12/site-packages/torchsurv/loss/cox.py:124: UserWarning: No events OR single sample. Returning zero loss for the batch
warnings.warn("No events OR single sample. Returning zero loss for the batch")
Adapted from tsai's weighted multiclass focal loss: https://github.com/timeseriesAI/tsai/blob/bdff96cc8c4c8ea55bc20d7cffd6a72e402f4cb2/tsai/losses.py#L116C1-L140C20
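A minimal sketch of that focal loss, following the tsai implementation linked above (details may differ slightly from the linked version):

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    # Weighted multiclass focal loss: cross-entropy down-weighted by
    # (1 - p_t)^gamma so that easy, well-classified samples contribute less.
    def __init__(self, alpha=None, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__()
        self.alpha = alpha          # optional per-class weight tensor
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, x, y):
        log_p = F.log_softmax(x, dim=-1)
        ce = F.nll_loss(log_p, y, weight=self.alpha, reduction='none')
        pt = log_p[torch.arange(len(x)), y].exp()  # prob of the true class
        loss = (1 - pt) ** self.gamma * ce
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss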
Adapted from torchsurv: https://github.com/Novartis/torchsurv/blob/main/src/torchsurv/loss/momentum.py
Survival framework using momentum-updated learning to decouple the effective batch size from model training. Two networks are trained concurrently: an online network and a target network. The online network's output batches are concatenated and consumed by the target network, virtually increasing its batch size.
The target network (k) is updated from the online network (q) with an exponential moving average (EMA) of its parameters. The online network is trained against a memory bank of previously computed log hazards, but gradients are tracked only for the current batch.
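A minimal sketch of the EMA update described above (illustrative only; the full framework lives in the linked momentum.py):

import torch

@torch.no_grad()
def ema_update(online: torch.nn.Module, target: torch.nn.Module, rate: float = 0.999):
    # Target (k) parameters drift toward the online (q) parameters:
    # theta_k <- rate * theta_k + (1 - rate) * theta_q
    for p_q, p_k in zip(online.parameters(), target.parameters()):
        p_k.mul_(rate).add_(p_q, alpha=1.0 - rate)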
Kullback-Leibler Divergence Loss with masking for ignore_index. Handles soft labels with ignore_index marked as -100.
Args:
  logits: [bs x n_classes x pred_labels] -- model predictions
  targets: [bs x n_classes x soft_labels] -- soft labels, with ignore_index positions marked as 0
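A minimal sketch of such a masked KL-divergence module, assuming ignored positions carry an all-zero soft-label column (the masking rule is an assumption, and the nested-tensor handling used in the second example below is omitted):

import torch
import torch.nn as nn
import torch.nn.functional as F

class KLDivLoss(nn.Module):
    # KL divergence between target and predicted distributions over dim 1,
    # averaged over positions whose soft-label column is not all zeros.
    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)
        mask = (targets.sum(dim=1, keepdim=True) > 0).float()  # valid positions
        kl = F.kl_div(log_probs, targets, reduction='none') * mask
        return kl.sum() / mask.sum().clamp(min=1.0)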
x = torch.randn(4, 5, 10)
y = torch.softmax(torch.randn(4, 5, 10), dim=1)
KLDivLoss()(x, y)
tensor(0.6321)
batch_size = 5
n_classes = 5
max_len = 480
# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))
# Create input tensors with different sequence lengths
x_list = [torch.randn(n_classes, length) for length in seq_lens]
y_list = [torch.softmax(torch.randn(n_classes, length), dim=0) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
y_nested = torch.nested.as_nested_tensor(y_list, layout=torch.jagged)
criterion = KLDivLoss()
criterion(x_nested, y_nested)
tensor(0.6608)
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.distributed.nn
# from util import misc

# def compute_cross_entropy(p, q):
#     q = F.log_softmax(q, dim=-1)
#     loss = torch.sum(p * q, dim=-1)
#     return - loss.mean()

# def stablize_logits(logits):
#     logits_max, _ = torch.max(logits, dim=-1, keepdim=True)
#     logits = logits - logits_max.detach()
#     return logits

# @torch.no_grad()
# def concat_all_gather(tensor):
#     """
#     Performs all_gather operation on the provided tensors.
#     *** Warning ***: torch.distributed.all_gather has no gradient.
#     """
#     tensors_gather = [torch.ones_like(tensor)
#                       for _ in range(torch.distributed.get_world_size())]
#     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
#     output = torch.cat(tensors_gather, dim=0)
#     return output

# class MultiPosConLoss(nn.Module):
#     """
#     Multi-Positive Contrastive Loss: https://arxiv.org/pdf/2306.00984.pdf
#     """
#     def __init__(self, temperature=0.1):
#         super(MultiPosConLoss, self).__init__()
#         self.temperature = temperature
#         self.logits_mask = None
#         self.mask = None
#         self.last_local_batch_size = None

#     def forward(self, x, y):
#         feats = x    # feats shape: [B, D]
#         labels = y   # labels shape: [B]

#         feats = F.normalize(feats, dim=-1, p=2)
#         local_batch_size = feats.size(0)

#         all_feats = torch.cat(torch.distributed.nn.all_gather(feats), dim=0)
#         all_labels = concat_all_gather(labels)  # no gradient gather

#         # compute the mask based on labels
#         if local_batch_size != self.last_local_batch_size:
#             mask = torch.eq(labels.view(-1, 1),
#                             all_labels.contiguous().view(1, -1)).float().to(device)
#             self.logits_mask = torch.scatter(
#                 torch.ones_like(mask),
#                 1,
#                 torch.arange(mask.shape[0]).view(-1, 1).to(device) +
#                 local_batch_size * misc.get_rank(),
#                 0
#             )
#             self.last_local_batch_size = local_batch_size
#             self.mask = mask * self.logits_mask

#         mask = self.mask

#         # compute logits
#         logits = torch.matmul(feats, all_feats.T) / self.temperature
#         logits = logits - (1 - self.logits_mask) * 1e9

#         # optional: minus the largest logit to stablize logits
#         logits = stablize_logits(logits)

#         # compute ground-truth distribution
#         p = mask / mask.sum(1, keepdim=True).clamp(min=1.0)
#         loss = compute_cross_entropy(p, logits)

#         return loss