preds:        [bs x num_patch x n_vars x patch_len]
targets:      [bs x num_patch x n_vars x patch_len]
mask:         [bs x num_patch x n_vars]
padding_mask: [bs x num_patch]
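A minimal sketch of how these four tensors might combine into a masked reconstruction loss. The function name `masked_patch_mse` is hypothetical, and the mask conventions (`mask == 1` marks patches that count toward the loss, `padding_mask == 1` marks real, non-padded patches) are assumptions; the source gives only the shapes.

```python
import torch


def masked_patch_mse(preds, targets, mask, padding_mask=None):
    """Mean squared error over selected patches only (hypothetical helper).

    preds, targets: [bs x num_patch x n_vars x patch_len]
    mask:           [bs x num_patch x n_vars], 1 = include in loss (assumed)
    padding_mask:   [bs x num_patch], 1 = real patch, 0 = padding (assumed)
    """
    # Per-patch error, averaged over patch_len -> [bs x num_patch x n_vars]
    loss = ((preds - targets) ** 2).mean(dim=-1)
    weights = mask.to(loss.dtype)
    if padding_mask is not None:
        # Broadcast the patch-level padding mask across n_vars
        weights = weights * padding_mask.unsqueeze(-1).to(loss.dtype)
    # Normalize by the number of contributing entries; clamp avoids 0/0
    return (loss * weights).sum() / weights.sum().clamp(min=1.0)


preds = torch.randn(8, 12, 7, 16)
targets = torch.randn(8, 12, 7, 16)
mask = (torch.rand(8, 12, 7) < 0.4).float()
padding_mask = torch.ones(8, 12)
loss = masked_patch_mse(preds, targets, mask, padding_mask)
```

Normalizing by the sum of the weights rather than by the tensor size keeps the loss scale independent of how many patches happen to be masked in a batch.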
Weighted multiclass focal loss, adapted from tsai: https://github.com/timeseriesAI/tsai/blob/bdff96cc8c4c8ea55bc20d7cffd6a72e402f4cb2/tsai/losses.py#L116C1-L140C20
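For orientation, a weighted multiclass focal loss can be sketched as below. This is an illustrative version and not necessarily identical to the tsai code or to the adaptation in this file; the function name `focal_loss` and its parameters (`alpha` as per-class weights, `gamma` as the focusing exponent) are assumptions.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, alpha=None, gamma=2.0, ignore_index=-100):
    """Weighted multiclass focal loss (illustrative sketch).

    logits:  [bs x n_classes]
    targets: [bs] integer class ids
    alpha:   optional [n_classes] per-class weights
    """
    log_prob = F.log_softmax(logits, dim=-1)
    prob = log_prob.exp()
    # Down-weight well-classified examples by (1 - p)^gamma
    focal = (1.0 - prob) ** gamma * log_prob
    # nll_loss selects the target class, negates, and applies class weights
    return F.nll_loss(focal, targets, weight=alpha,
                      ignore_index=ignore_index, reduction="mean")


logits = torch.randn(16, 5)
targets = torch.randint(0, 5, size=(16,))
loss = focal_loss(logits, targets, alpha=torch.ones(5), gamma=2.0)
```

With `gamma=0` and no class weights, this reduces to ordinary cross-entropy, which is a convenient sanity check.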
Kullback-Leibler Divergence Loss with masking for ignore_index. Handles soft labels with ignore_index marked as -100.
Args:
    logits:  [bs x n_classes x pred_labels] - model predictions
    targets: [bs x n_classes x soft_labels] - soft labels, with ignore_index positions marked as -100,
             or [bs x n_labels] - hard labels
x = torch.randn(4, 5, 10)
y = torch.randint(0, 5, size=(4, 10))
y_og = y.clone()
y[0, 0] = -100
KLDivLoss(ignore_index=-100)(x, y)
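The masking behaviour described for this loss can be sketched as a standalone function. The name `masked_kl_div` is hypothetical and is not the `KLDivLoss` class used above; the sketch assumes that a soft-label position is ignored when any of its class entries equals `ignore_index`, and that the loss is averaged over the remaining positions.

```python
import torch
import torch.nn.functional as F


def masked_kl_div(logits, targets, ignore_index=-100):
    """KL(targets || softmax(logits)) averaged over non-ignored positions.

    logits:  [bs x n_classes x n_labels]
    targets: [bs x n_classes x n_labels] soft labels, or [bs x n_labels] hard labels
    """
    n_classes = logits.shape[1]
    if targets.dim() == logits.dim() - 1:
        # Hard labels: build one-hot targets, using class 0 as a placeholder
        # at ignored positions (they are masked out below)
        valid = targets != ignore_index                      # [bs x n_labels]
        one_hot = F.one_hot(targets.clamp(min=0), n_classes)  # [bs x n_labels x n_classes]
        soft = one_hot.permute(0, 2, 1).to(logits.dtype)
    else:
        # Soft labels: a position is ignored if any class entry is ignore_index
        valid = ~(targets == ignore_index).any(dim=1)        # [bs x n_labels]
        soft = targets.to(logits.dtype)
    # Zero out ignored positions so they cannot produce NaNs
    soft = torch.where(valid.unsqueeze(1), soft, torch.zeros_like(soft))
    log_q = F.log_softmax(logits, dim=1)
    # xlogy(p, p) = p * log(p), with the convention 0 * log(0) = 0
    kl = (torch.xlogy(soft, soft) - soft * log_q).sum(dim=1)  # [bs x n_labels]
    valid = valid.to(kl.dtype)
    return (kl * valid).sum() / valid.sum().clamp(min=1.0)


x = torch.randn(4, 5, 10)
y = torch.randint(0, 5, size=(4, 10))
y[0, 0] = -100
loss = masked_kl_div(x, y, ignore_index=-100)
```

For hard labels the target distribution is one-hot, so its entropy term vanishes and the result coincides with cross-entropy computed with the same `ignore_index`.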