JEPA For Sleep


Torch


source

create_masks


def create_masks(
    x, patch_size, patch_stride, context_mask_range, target_mask_range, melt_channels_to_batch:bool=False,
    return_nested:bool=False
):

source

apply_masks


def apply_masks(
    x, masks
):
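Together these form the masking API: create_masks samples context and target patch indices, and apply_masks gathers the selected patches. A minimal sketch, assuming a dense batch and that the returned masks are patch-index tensors (see the nested-tensor example further down):

import torch

x = torch.randn(2, 7, 5000)                        # (batch, channels, time)
masks, non_masks = create_masks(
    x, 10, 10,                                     # patch_size=10, patch_stride=10
    (0.5, 1.0),                                    # context_mask_range
    (0.05, 0.3),                                   # target_mask_range
)
# context_tokens = apply_masks(tokens, non_masks)  # visible patches for the encoder
# target_tokens = apply_masks(tokens, masks)       # patches the predictor must recover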

source

TSTBlock


def TSTBlock(
    d_model, n_heads, d_ff:int=256, attn_dropout:int=0, dropout:float=0.0, bias:bool=True, activation:str='gelu',
    pre_norm:bool=False, rotary_pes:bool=False
):

A transformer encoder block for time-series tokens: multi-head self-attention followed by a feed-forward network, with configurable dropout, optional pre-norm, and optional rotary positional embeddings. Inherits from nn.Module.
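A rough usage sketch, assuming the block's forward takes token embeddings shaped (batch, seq_len, d_model), the usual convention for such blocks:

import torch

block = TSTBlock(d_model=128, n_heads=4, d_ff=256, pre_norm=True)
tokens = torch.randn(2, 30, 128)   # (batch, num_patches, d_model)
out = block(tokens)                # same shape as the input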


source

JEPABlock


def JEPABlock(
    dim, num_heads, mlp_ratio:float=4.0, qkv_bias:bool=False, qk_scale:NoneType=None, drop:float=0.0,
    attn_drop:float=0.0, act_layer:type=GELU, norm_layer:type=LayerNorm, rotary_pes:bool=False
):

A ViT-style transformer block (multi-head attention plus an MLP with mlp_ratio expansion), as used in I-JEPA-style encoders and predictors, with optional rotary positional embeddings. Inherits from nn.Module.
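Instantiation mirrors the ViT block it is modeled on; a rough sketch under the same (batch, tokens, dim) assumption:

import torch

block = JEPABlock(dim=512, num_heads=8, mlp_ratio=4.0, qkv_bias=True)
z = block(torch.randn(2, 30, 512))   # (batch, num_patches, dim) in and out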

For variable-length recordings, create_masks accepts jagged nested tensors; with melt_channels_to_batch=True, the 7 channels of each of the 2 sequences are flattened into an effective batch of 14 masked sequences:

import torch

max_len = 5000
batch_size = 2
seq_lens = torch.randint(4000, max_len, (batch_size,))

# Create input tensors with different sequence lengths
# (a dense equivalent would be x = torch.randn(2, 7, 5000))
x_list = [torch.randn(7, int(length)) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
masks, non_masks = create_masks(x_nested, 10, 10, (0.05, 0.08), (0.0, 0.1), melt_channels_to_batch=True, return_nested=True)
masks.shape
torch.Size([14, j94])

source

Encoder


def Encoder(
    c_in, num_patches, patch_size, patch_stride, d_model, nhead, num_layers, use_tst_block:bool=False,
    shared_embedding:bool=True, pe_type:str='tAPE', mlp_ratio:float=4.0, qkv_bias:bool=True, qk_scale:NoneType=None,
    drop_rate:float=0.0, attn_drop_rate:float=0.0, norm_layer:type=LayerNorm, jepa:bool=True,
    embed_activation:GELU=GELU(approximate='none'), init_std:float=0.02, tokenizer_type:str='simple',
    tokenizer_kwargs:dict={}
):

A patch-based transformer encoder: the input series is tokenized into patches (tokenizer_type selects the patch embedding, e.g. 'simple' or 'linear'), positional encodings of pe_type are added, and the tokens are processed by num_layers transformer blocks (TSTBlock when use_tst_block=True, otherwise JEPABlock). Inherits from nn.Module.
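A rough instantiation sketch, reusing the configuration from the JEPASimpleLightning example below (the exact forward signature, e.g. whether masks are passed in, is not shown here):

import torch
from torch import nn

enc = Encoder(
    c_in=7, num_patches=30, patch_size=128, patch_stride=128,
    d_model=512, nhead=8, num_layers=3, norm_layer=nn.LayerNorm, jepa=True,
)
x = torch.randn(2, 7, 128 * 30)   # time dimension = num_patches * patch_size
# z = enc(x)                      # patch embeddings of width d_model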


source

jepa_mse_loss


def jepa_mse_loss(
    pred, target_ema
):

Compute the MSE loss between predictions and targets.

Args:
    pred: Predictions (nested or regular tensor)
    target_ema: Target embeddings from the EMA encoder

Returns:
    Scalar loss value
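A minimal sketch of the computation described above, assuming dense tensors (the real function also handles nested tensors):

import torch.nn.functional as F

def jepa_mse_loss_dense(pred, target_ema):
    # EMA targets are treated as constants: no gradient flows into the target encoder.
    return F.mse_loss(pred, target_ema.detach())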


source

Predictor


def Predictor(
    num_patches, encoder_embed_dim:int=128, predictor_embed_dim:int=128, nhead:int=2, num_layers:int=1,
    use_tst_block:bool=False, pe_type:str='tAPE', mlp_ratio:float=4.0, qkv_bias:bool=True, qk_scale:NoneType=None,
    drop_rate:float=0.0, attn_drop_rate:float=0.0, norm_layer:type=LayerNorm,
    embed_activation:GELU=GELU(approximate='none'), init_std:float=0.02,
    c_in_mask_tokens:int=1, # number of channels in the encoder (if treating channels sep)
    shuffle:bool=True
):

The JEPA predictor: maps context-patch embeddings (of encoder_embed_dim) into a narrower predictor_embed_dim space, inserts learnable mask tokens at the target positions, and predicts the target-patch embeddings with a lightweight transformer. Inherits from nn.Module.
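A rough instantiation sketch matching predictor_kwargs below (the forward call is assumed, not documented here):

pred = Predictor(num_patches=30, encoder_embed_dim=512,
                 predictor_embed_dim=128, nhead=4, num_layers=2)
# p = pred(context_embeddings, masks)   # predicted embeddings for the masked patches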


source

JEPASimpleLightning


def JEPASimpleLightning(
    learning_rate, train_size, batch_size, n_gpus, patchtsjepa_encoder_kwargs, patchtsjepa_predictor_kwargs,
    num_nodes:int=1, weight_decay:float=0.04, use_weight_decay_scheduler:bool=False, final_weight_decay:float=0.4,
    epochs:int=100, optimizer_type:str='adamw', scheduler_type:str='OneCycle', target_mask_range:tuple=(0.05, 0.3),
    context_mask_range:tuple=(0.5, 1.0), mask_block_range:tuple=(1, 30), ema_decay:float=0.996,
    scheduler_kwargs:dict={}, transforms:NoneType=None, loss_fn:function=jepa_mse_loss, linear_probe:bool=False
):

A PyTorch Lightning module that trains the encoder/predictor pair with the JEPA objective: batches are masked with create_masks, the context is encoded, the predictor reconstructs target embeddings produced by an EMA copy of the encoder, and loss_fn (jepa_mse_loss by default) scores the predictions. Optionally attaches a linear probe.
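The target encoder is maintained as an exponential moving average of the online encoder, controlled by ema_decay. A minimal sketch of the standard JEPA-style update (the ema_update helper below is illustrative, not this module's API):

import torch

@torch.no_grad()
def ema_update(online, target, decay=0.996):
    # target <- decay * target + (1 - decay) * online, parameter by parameter
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1 - decay)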

encoder_kwargs = dict(
    c_in=7,
    num_patches=30,
    patch_size=128,
    patch_stride=128,
    d_model=512,
    nhead=8,
    num_layers=3,
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    norm_layer=nn.LayerNorm,
    jepa=True,
    embed_activation=nn.GELU(),
    tokenizer_type='linear',
    pe_type='tAPE',
    tokenizer_kwargs=dict(bottleneck_channels=32, kernel_size=64, depth=1, residual=True, bottleneck=True),
    use_tst_block=True,
    shared_embedding=False,
)

predictor_kwargs = dict(
    num_patches=30,
    encoder_embed_dim=512,
    predictor_embed_dim=128,
    nhead=4,
    pe_type='tAPE',
    num_layers=2,
)

jepa_lightning = JEPASimpleLightning(
    learning_rate=0.001,
    train_size=1000,
    batch_size=10,
    mask_block_range=(1, 1),
    n_gpus=1,
    patchtsjepa_encoder_kwargs=encoder_kwargs,
    patchtsjepa_predictor_kwargs=predictor_kwargs,
    linear_probe=True,
)

batch_size = 2
n_vars = 7
max_len = 128*30

# A dense batch would be x = (torch.randn(batch_size, n_vars, max_len), None);
# here we build a jagged batch of variable-length sequences instead.
seq_lens = torch.randint(128*20, max_len, (batch_size,))

x_list = [torch.randn(n_vars, int(length)) for length in seq_lens]
x_nested = torch.nested.as_nested_tensor(x_list, layout=torch.jagged)
print(x_nested.shape)

x = (x_nested, None)   # batches are (signal, label) pairs; the label is unused here
o = jepa_lightning.validation_step(x, 0)

# from pytorch_lightning.utilities.model_summary import summarize
# summarize(jepa_lightning)