Loss functions

I'm lost too. PhDs are dumb.

source

NTXentLoss


def NTXentLoss(
    temperature:float=0.2
):

Normalized temperature-scaled cross-entropy (NT-Xent) loss. This is far more efficient than `NTXentLoss_poly`, but it does not support polynomial expansion.


source

NTXentLoss_poly


def NTXentLoss_poly(
    batch_size, temperature:float=0.2, polynomial_expanison:bool=True
):

Normalized temperature scaled cross entropy loss.

x1, x2 = torch.randn(2,4,100,100), torch.randn(2,4,100,100)

l = NTXentLoss(temperature=0.5)
l(x1.flatten(start_dim=-3),x2.flatten(start_dim=-3)), l(x1.permute(0,1,3,2).flatten(start_dim=-3), x2.permute(0,1,3,2).flatten(start_dim=-3))
(tensor(1.0870), tensor(1.0870))
batch_size = 2
n_vars = 7
max_len = 480
patch_len = 100

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(length, n_vars, patch_len) for length in seq_lens]
x_list2 = [torch.randn(length, n_vars, patch_len) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
x_nested2 = torch.nested.as_nested_tensor(x_list2, layout=torch.jagged)

loss = NTXentLoss(temperature=0.5)
loss(x_nested, x_nested2)
tensor(1.0981)

source

mse_variance_loss


def mse_variance_loss(
    preds, target, representations, alpha:float=0.2
):

preds: [bs x num_patch x n_vars x patch_len]; targets: [bs x num_patch x n_vars x patch_len]; representations: [bs x n_vars x d_model x num_patch]


source

smoothl1_loss


def smoothl1_loss(
    preds, target
):

Call self as a function.


source

huber_loss


def huber_loss(
    preds, target, delta:int=1
):

preds: [bs x num_patch x n_vars x patch_len] targets: [bs x num_patch x n_vars x patch_len]


source

cosine_similarity_loss


def cosine_similarity_loss(
    preds, target
):

preds: [bs x num_patch x n_vars x patch_len] targets: [bs x num_patch x n_vars x patch_len]


source

mape


def mape(
    preds, target
):

Call self as a function.


source

mae_loss


def mae_loss(
    preds, target
):

preds: [bs x num_patch x n_vars x patch_len] targets: [bs x num_patch x n_vars x patch_len]


source

mse_loss


def mse_loss(
    preds, target
):

preds: [bs x num_patch x n_vars x patch_len] targets: [bs x num_patch x n_vars x patch_len]

x = torch.randn(4,300, 7, 100)
y = torch.randn(4,300, 7, 100)

huber_loss(x,y)
tensor(0.7208)
batch_size = 2
n_vars = 7
max_len = 480
patch_len = 100

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(length, n_vars, patch_len) for length in seq_lens]
x_list2 = [torch.randn(length, n_vars, patch_len) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
x_nested2 = torch.nested.as_nested_tensor(x_list2, layout=torch.jagged)

smoothl1_loss(x_nested, x_nested2)
tensor(0.7184)

source

patch_continuity_loss


def patch_continuity_loss(
    preds
):

preds: [bs x num_patch x n_vars x patch_len]

x = torch.randn(2,2, 2, 10)

patch_continuity_loss(x)
tensor(1.8017)

source

demographic_adversarial_loss


def demographic_adversarial_loss(
    demo_predictions, demo_targets
):

Call self as a function.


source

nll_logistic_hazard


def nll_logistic_hazard(
    phi, events, idx_durations, reduction:str='mean'
):

Adapted from https://github.com/havakv/pycox/blob/3eccdd7fd9844a060f50fdcc315659f33a2d2dc1/pycox/models/loss.py#L18 Negative log-likelihood of the discrete time hazard parametrized model LogisticHazard [1].

Arguments: phi {torch.tensor} – Estimates in (-inf, inf), where hazard = sigmoid(phi). idx_durations {torch.tensor} – Event times represented as indices. events {torch.tensor} – Indicator of event (1.) or censoring (0.). Same length as ‘idx_durations’. reduction {string} – How to reduce the loss. ‘none’: No reduction. ‘mean’: Mean of tensor. ‘sum’: Sum of tensor.

Returns: torch.tensor – The negative log-likelihood.

References: [1] Håvard Kvamme and Ørnulf Borgan. Continuous and Discrete-Time Survival Prediction with Neural Networks. arXiv preprint arXiv:1910.06724, 2019. https://arxiv.org/pdf/1910.06724.pdf


source

multilabel_cox_loss


def multilabel_cox_loss(
    x, event, time, n_labels, ties_method:str='efron', reduction:str='mean'
):

Multilabel Cox loss function. x: [bs x n_labels] – “log hazard”; event: [bs x n_labels] – 1 for event, 0 for censored; time: [bs x n_labels] – time of event or time censored.


source

cox_ph_loss


def cox_ph_loss(
    x, event, time
):

Call self as a function.

x = torch.randn(2, 14)
event = torch.randint(0, 2, (2,14))
time = torch.randint(1, 300, (2,14))

cox_ph_loss(x, event, time), multilabel_cox_loss(x, event, time, 14)
/opt/miniconda3/envs/timeflies/lib/python3.12/site-packages/torchsurv/loss/cox.py:124: UserWarning: No events OR single sample. Returning zero loss for the batch
  warnings.warn("No events OR single sample. Returning zero loss for the batch")
(tensor(0.3138), tensor([3.1433], grad_fn=<AddBackward0>))
x = torch.randn(2, 20).float()
event = torch.randint(0, 2, (2,)).float()
time_idx = torch.randint(0, 18, (2,))

events = event.view(-1, 1)
idx_durations = time_idx.view(-1, 1)
y_bce = torch.zeros_like(x).scatter(1, idx_durations, events)
bce = F.binary_cross_entropy_with_logits(x, y_bce, reduction='none')
loss = bce.cumsum(1).gather(1, idx_durations).view(-1)

nll_logistic_hazard(x, event, time_idx)
tensor(7.4561)
x = torch.randn(4, 15).float()
event = torch.rand(4,2)
time_idx_1 = torch.randint(0, 15, (4,))
time_idx_2 = torch.randint(0, 15, (4,))
time_idx = torch.stack([time_idx_1, time_idx_2], dim=1)

events_1 = event[:, 0].view(-1, 1)
idx_durations_1 = time_idx[:, 0].view(-1, 1)
y_bce_1 = torch.zeros_like(x).scatter(1, idx_durations_1, events_1)
bce_1 = F.binary_cross_entropy_with_logits(x, y_bce_1, reduction='none')
loss_1 = bce_1.cumsum(1).gather(1, idx_durations_1).view(-1)

events_2 = event[:, 1].view(-1, 1)
idx_durations_2 = time_idx[:, 1].view(-1, 1)
y_bce_2 = torch.zeros_like(x).scatter(1, idx_durations_2, events_2)
bce_2 = F.binary_cross_entropy_with_logits(x, y_bce_2, reduction='none')
loss_2 = bce_2.cumsum(1).gather(1, idx_durations_2).view(-1)

# Sum the losses (soft labels already contain the mixture weights)
loss = loss_1 + loss_2
#from torchsurv.metrics.auc import Auc
n = 10

time = torch.randint(low=5, high=250, size=(n,), device='cpu').float()

event = torch.randint(low=0, high=2, size=(n,), device='cpu').bool()

estimate = torch.randn((n,), device='cpu')
eval_times = torch.tensor([30.], device='cpu')
auc = Auc(checks=False)

auc(estimate, event, time, auc_type = "cumulative", new_time=eval_times)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
Cell In[18], line 11
      9 estimate = torch.randn((n,), device='cpu')
     10 eval_times = torch.tensor([30.], device='cpu')
---> 11 auc = Auc(checks=False)
     13 auc(estimate, event, time, auc_type = "cumulative", new_time=eval_times)

NameError: name 'Auc' is not defined

source

CrossEntropyLoss


def CrossEntropyLoss(
    ignore_index:int=-100, reduction:str='mean', weight:NoneType=None, label_smoothing:int=0, soft_labels:bool=False
):

Cross entropy loss with ignore_index.

criterion = CrossEntropyLoss(ignore_index=0, weight=torch.tensor([1,1,1,1,1]), reduction='mean', label_smoothing=0.)
batch_size = 10

n_patch = 721
n_class = 5
#m = torch.nn.Softmax(dim=-1)
logits = torch.randn(batch_size, n_class, n_patch)
target = torch.randint(0, n_class, size=(batch_size, n_patch))
criterion(logits, target), nn.CrossEntropyLoss(ignore_index=0, reduction='mean')(logits, target)
(tensor(1.9717), tensor(1.9717))
# soft labels
criterion = CrossEntropyLoss(ignore_index=100, weight=torch.tensor([1,1,1,1,1]), reduction='mean', label_smoothing=0., soft_labels=True)
logits = torch.randn(batch_size, n_class, n_patch)
target = torch.rand(batch_size, n_class, n_patch)
criterion(logits, target)
tensor(49.1424)
batch_size = 2
n_classes = 5
max_len = 480

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(n_classes, length) for length in seq_lens]
y_list = [torch.randint(0, n_classes, size=(length,)) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
y_nested = torch.nested.as_nested_tensor(y_list, layout=torch.jagged)

criterion = CrossEntropyLoss(ignore_index=-100, weight=torch.tensor([1,1,1,1,1]), reduction='mean', label_smoothing=0., soft_labels=False)
criterion(x_nested, y_nested)
tensor(2.0479)

source

FocalLoss


def FocalLoss(
    weight:NoneType=None, gamma:float=2.0, reduction:str='mean', ignore_index:int=-100
):

Adapted from tsai: weighted multiclass focal loss. https://github.com/timeseriesAI/tsai/blob/bdff96cc8c4c8ea55bc20d7cffd6a72e402f4cb2/tsai/losses.py#L116C1-L140C20

criterion = FocalLoss(gamma=0.7, weight=None, ignore_index=0)
batch_size = 10

n_patch = 721
n_class = 2
#m = torch.nn.Softmax(dim=-1)
logits = torch.randn(batch_size, n_class, n_patch)
target = torch.randint(0, n_class, size=(batch_size, n_patch))
criterion(logits, target.float())
tensor(0.7082)
# soft labels
criterion = FocalLoss(gamma=0.7, weight=torch.tensor([1,1,1,1,1]), ignore_index=0)
logits = torch.randn(batch_size, n_class, n_patch)
target = torch.rand(batch_size, n_class, n_patch)
criterion(logits, target)
tensor(2.8824)
batch_size = 2
n_classes = 5
max_len = 480

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(n_classes, length) for length in seq_lens]
y_list = [torch.randint(0, n_classes, size=(length,)) for length in seq_lens]
y_list[0][0] = -100
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
y_nested = torch.nested.as_nested_tensor(y_list, layout=torch.jagged)
print(x_nested.shape, y_nested.shape)
criterion = FocalLoss(gamma=0.7, weight=torch.tensor([1,1,1,1,1]), ignore_index=-100)
criterion(x_nested, y_nested)
torch.Size([2, 5, j7]) torch.Size([2, j8])
tensor(1.7631)
batch_size = 2
n_classes = 5
max_len = 480

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(n_classes, length) for length in seq_lens]
y_list = [torch.randn(n_classes, length) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
y_nested = torch.nested.as_nested_tensor(y_list, layout=torch.jagged)

criterion = FocalLoss(gamma=0.7, weight=None, ignore_index=0)
criterion(x_nested, y_nested)
tensor(2.2618)

source

Momentum


def Momentum(
    backbone:Module, loss:Callable, batchsize:int=16, steps:int=10, rate:float=0.999
):

Adapted from torchsurv ( https://github.com/Novartis/torchsurv/blob/main/src/torchsurv/loss/momentum.py ). Survival framework that uses momentum-update learning to decouple batch size during model training. Two networks are trained concurrently: an online network and a target network. The online network's output batches are concatenated and used by the target network, which virtually increases its batch size.

The target network (k) is updated using an exponential moving average (EMA) of the parameters of the online network (q). The online network is trained using a memory bank of previously computed log hazards, but only the loss from the current batch is tracked.


source

KLDivLoss


def KLDivLoss(
    reduction:str='mean'
):

Kullback-Leibler Divergence Loss with masking for ignore_index. Handles soft labels with ignore_index marked as -100.

Args: logits: [bs x n_classes x pred_labels] - model predictions targets: [bs x n_classes x soft_labels] - soft labels, with ignore_index positions marked as 0

x = torch.randn(4,5,10)
y = torch.softmax(torch.randn(4,5,10), dim=1)

KLDivLoss()(x,y)
tensor(0.6321)
batch_size = 5
n_classes = 5
max_len = 480

# Create sequences of different lengths
seq_lens = torch.randint(50, max_len, (batch_size,))

# Create input tensors with different sequence lengths
x_list = [torch.randn(n_classes, length) for length in seq_lens]
y_list = [torch.softmax(torch.randn(n_classes, length), dim=0) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
y_nested = torch.nested.as_nested_tensor(y_list, layout=torch.jagged)

criterion = KLDivLoss()
criterion(x_nested, y_nested)
tensor(0.6608)
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import torch.distributed.nn

# from util import misc


# def compute_cross_entropy(p, q):
#     q = F.log_softmax(q, dim=-1)
#     loss = torch.sum(p * q, dim=-1)
#     return - loss.mean()


# def stablize_logits(logits):
#     logits_max, _ = torch.max(logits, dim=-1, keepdim=True)
#     logits = logits - logits_max.detach()
#     return logits


# @torch.no_grad()
# def concat_all_gather(tensor):
#     """
#     Performs all_gather operation on the provided tensors.
#     *** Warning ***: torch.distributed.all_gather has no gradient.
#     """
#     tensors_gather = [torch.ones_like(tensor)
#                       for _ in range(torch.distributed.get_world_size())]
#     torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

#     output = torch.cat(tensors_gather, dim=0)
#     return output


# class MultiPosConLoss(nn.Module):
#     """
#     Multi-Positive Contrastive Loss: https://arxiv.org/pdf/2306.00984.pdf
#     """

#     def __init__(self, temperature=0.1):
#         super(MultiPosConLoss, self).__init__()
#         self.temperature = temperature
#         self.logits_mask = None
#         self.mask = None
#         self.last_local_batch_size = None

#     def forward(self, x, y):
#         feats = x    # feats shape: [B, D]
#         labels = y    # labels shape: [B]

#         feats = F.normalize(feats, dim=-1, p=2)
#         local_batch_size = feats.size(0)

#         all_feats = torch.cat(torch.distributed.nn.all_gather(feats), dim=0)
#         all_labels = concat_all_gather(labels)  # no gradient gather

#         # compute the mask based on labels
#         if local_batch_size != self.last_local_batch_size:
#             mask = torch.eq(labels.view(-1, 1),
#                             all_labels.contiguous().view(1, -1)).float().to(device)
#             self.logits_mask = torch.scatter(
#                 torch.ones_like(mask),
#                 1,
#                 torch.arange(mask.shape[0]).view(-1, 1).to(device) +
#                 local_batch_size * misc.get_rank(),
#                 0
#             )

#             self.last_local_batch_size = local_batch_size
#             self.mask = mask * self.logits_mask

#         mask = self.mask

#         # compute logits
#         logits = torch.matmul(feats, all_feats.T) / self.temperature
#         logits = logits - (1 - self.logits_mask) * 1e9

#         # optional: minus the largest logit to stablize logits
#         logits = stablize_logits(logits)

#         # compute ground-truth distribution
#         p = mask / mask.sum(1, keepdim=True).clamp(min=1.0)
#         loss = compute_cross_entropy(p, logits)

#         return loss