SleepJEPA
Overview
SleepJEPA is a Python package that uses at-home sleep study data to classify sleep stages, assess sleepiness, and estimate long-term disease risk. The model was trained with a Joint-Embedding Predictive Architecture (JEPA), modified for signal data and sleep studies. The repository is built with nbdev, so the package is developed in Jupyter notebooks.
See the publication in tbd…
Install
pip install sleepjepa
Repository Structure and Usage
This is an nbdev repository, which means the package is developed in Jupyter notebooks located in the nbs/ directory. Any modifications or additions to the sleepjepa package should be made by editing these notebooks.
To build the package, run nbdev_prepare in the terminal. This generates the sleepjepa package and all of its Python modules in the sleepjepa/ directory, which can then be imported and used in other Python projects.
To add new functionality, create a new notebook in the nbs/ directory or extend an existing one, following the instructions in the nbdev documentation. Then run nbdev_prepare to regenerate the sleepjepa package with the new functionality.
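As a rough illustration of the notebook workflow described above, the sketch below shows what the cells of a notebook in nbs/ might look like. The `#| default_exp` and `#| export` comments are standard nbdev directives; the module name `epochs`, the function `split_into_epochs`, and the 30-second default are hypothetical and are not taken from the sleepjepa codebase.

```python
# --- notebook cell 1: names the module this notebook exports to ---
# (hypothetical module name; would become sleepjepa/epochs.py)
#| default_exp epochs

# --- notebook cell 2: exported into the package by nbdev_prepare ---
#| export
from typing import List, Sequence


def split_into_epochs(
    signal: Sequence[float],
    sample_rate_hz: int,
    epoch_seconds: int = 30,  # assumed default, chosen for illustration
) -> List[List[float]]:
    "Split a 1-D signal into non-overlapping epochs, dropping any trailing partial epoch."
    samples_per_epoch = sample_rate_hz * epoch_seconds
    if samples_per_epoch <= 0:
        raise ValueError("sample_rate_hz and epoch_seconds must be positive")
    n_epochs = len(signal) // samples_per_epoch
    return [
        list(signal[i * samples_per_epoch:(i + 1) * samples_per_epoch])
        for i in range(n_epochs)
    ]


# --- notebook cell 3: no export directive, so it stays in the notebook ---
# nbdev runs cells like this one as tests when nbdev_prepare is executed.
example = split_into_epochs(list(range(10)), sample_rate_hz=2, epoch_seconds=2)
assert example == [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Keeping the example and its assertion in the same notebook as the exported function means the documentation, the test, and the code stay together, which is the main reason to edit the notebooks rather than the generated modules.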
Directory Structure:
- nbs/: Contains the source notebooks that generate the Python package
- jobs/: Contains processing and training scripts
  - apples/: Processing scripts for the APPLES dataset
  - human_sleep_project/: Processing scripts for the Human Sleep Project dataset
  - jepa/: Training script for the JEPA sleep model
  - mesa/: Processing scripts for the MESA dataset
  - mnc/: Processing scripts for the MNC dataset
  - mros/: Processing scripts for the MrOS dataset
  - shhs/: Processing scripts for the SHHS dataset
  - wsc/: Processing scripts for the WSC dataset
  - sleep_outcomes/:
    - config/: Config files for training outcomes/models
    - train_age.py: Model training script for SleepJEPA representations to estimate chronological age
    - train_obj_sleepiness.py: Model training script for SleepJEPA representations to estimate objective sleepiness and narcolepsy outcomes
    - train_sleep_stages.py: Model training script for SleepJEPA representations to classify sleep stages
    - train_demographics_long_term_outcome.py: Model training script for a demographics-only model for long-term disease risk estimation
    - train_sleep_long_term_outcome.py: Model training script for SleepJEPA representations, with or without demographics, for long-term disease risk estimation
Technical Details
- We trained the foundation model on NVIDIA H100 NVL GPUs using PyTorch Lightning.
- We monitored training using the Weights & Biases platform.