PFTSleep

Overview

PFTSleep is a Python package for sleep stage classification using a pre-trained foundational transformer. The repository is built using nbdev, which means the package is developed in Jupyter notebooks.

See the publication in SLEEP and original preprint for more details.

Install

pip install PFTSleep

Inference

To perform inference on unseen data (using EDF file paths as input), use the pftsleep_inference.sh, pftsleep_inference.py, and pftsleep_inference_config.yaml files.

Please create an account on Hugging Face and request access to the models here. You will also need to create a personal access token with read access. Read more here.

Next, download the encoder and sleep stage classifier models:

from pftsleep.inference import download_pftsleep_models

download_pftsleep_models(models_dir='', token='YOUR_HF_TOKEN')

You will also be prompted to download the files when running the inference script if they are not in models_dir.

After the models are downloaded, update the pftsleep_inference_config.yaml file with a path to your EDF file or EDF directory. You can pass either a single EDF file or a directory of EDFs. You can also use glob syntax to select EDFs within subdirectories (e.g. /path/to/base/directory/**/).

Because naming conventions for signal channels differ between recordings, the dataloader will fail if you pass multiple EDFs with inconsistent channel names. In that case, rewrite the EDFs to a consistent channel-name format or perform inference on them one at a time.

If a specific channel is not available for a given EDF or set of EDFs, pass the keyword "null" or "dummy" as that channel's name parameter in the YAML if you'd like to see how the model performs with that channel set to all zeros.
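As an illustration, a channel section of the YAML might look like the following. Note that these key names are hypothetical and not taken from the shipped pftsleep_inference_config.yaml; check that file for the actual parameter names.

```yaml
# Hypothetical channel-name section -- key names are illustrative only.
channels:
  eeg: "C4-M1"
  left_eog: "E1-M2"
  emg: "null"    # no EMG channel available: model receives all zeros
  ecg: "dummy"   # same effect as "null"
```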

The model expects referenced EEG, Left EOG, EMG, and ECG channels. If your channels are unreferenced, you may pass the corresponding reference channels. The model was trained with the following referenced channels (priority was given to those listed first):

EEG: C4-M1 or C3-M2
EMG: Chin1-Chin2 or Chin1-Chin3
ECG: Augmented lead 2 (or ECG (LL) - ECG (RA))
Left EOG: E1-M2

Check the slumber.py source code for NSRR-specific channel names (under the SHHS_CHANNELS, MROS_CHANNELS, WSC_CHANNELS, etc. variables) if the above is unclear.

For the device parameter, use "cpu" (slowest), a CUDA GPU (e.g. "cuda:0"), or "mps" (Apple Silicon on macOS).

Prediction logits of shape [bs x 5 x 960] are output. Dimension 1 (size 5) holds the per-stage logits, where index 0 is wake, 1 is N1, 2 is N2, 3 is N3, and 4 is REM. To retrieve probabilities, apply torch's softmax function along dimension 1 of the tensor. Note that the model expects an 8-hour input and returns 960 class predictions, one for each 30-second sleep epoch within the 8 hours. If the sleep study is longer than 8 hours, stages after the 8-hour time point will not be predicted. If the sleep study is shorter than 8 hours, stages predicted after the true length of the study should be ignored (the model still outputs all 960 predictions). Please file an issue if this becomes a major problem and we can work on a solution. We are also working on models that accept variable-length input.
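The post-processing described above can be sketched as follows. This uses NumPy for illustration; on the real output tensor you would use torch.softmax(logits, dim=1) and argmax over the same dimension.

```python
import numpy as np

# Fake logits standing in for the model output: [batch x 5 stages x 960 epochs]
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 5, 960))

# Softmax over the stage dimension (axis 1) to get per-epoch probabilities
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)

# Most likely stage per 30-second epoch: 0=wake, 1=N1, 2=N2, 3=N3, 4=REM
stages = probs.argmax(axis=1)  # shape [1, 960]

# For a study shorter than 8 hours, keep only epochs within its true length,
# e.g. a 6.5-hour study covers 6.5 * 60 * 2 = 780 thirty-second epochs.
n_epochs = 780
stages = stages[:, :n_epochs]
```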

Finally, to run predictions on a single EDF file or a directory of EDF files, make the shell script executable:

chmod +x pftsleep_inference.sh

Then run it, passing the config file as the main argument:

./pftsleep_inference.sh pftsleep_inference_config.yaml

Predictions are saved as a PyTorch tensor file (.pt) at the location specified in the YAML.

Additionally, you can set the save_hypjson parameter in the YAML to true. This will apply softmax, take the argmax over stages, and save the predictions as a HYPJSON file (with the same filename as the EDF file, plus '_pftsleep.HYPJSON'). You can also use the write_pred_to_hypjson function for individual files. In this function, sleep stages are mapped to typical HYPJSON conventions (for example, REM is mapped to the integer 5).
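A minimal sketch of the stage remapping step: the REM-to-5 mapping is stated above, but the other values follow the common AASM-style convention and are assumptions here, not read from write_pred_to_hypjson.

```python
# Model stage indices: 0=wake, 1=N1, 2=N2, 3=N3, 4=REM
# Assumed HYPJSON-style codes: wake=0, N1=1, N2=2, N3=3, REM=5 (REM=5 per the README)
MODEL_TO_HYPJSON = {0: 0, 1: 1, 2: 2, 3: 3, 4: 5}

def remap_stages(stage_indices):
    """Map model argmax indices to HYPJSON-style stage codes."""
    return [MODEL_TO_HYPJSON[int(s)] for s in stage_indices]

# Example: wake, N2, REM, N3
print(remap_stages([0, 2, 4, 3]))  # -> [0, 2, 5, 3]
```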

Repository Structure and Usage

This is an nbdev repository, which means the package is developed in Jupyter notebooks located in the nbs/ directory. Any modifications or additions to the PFTSleep package should be made by editing these notebooks.

To build the package, run nbdev_prepare in the terminal. This generates the PFTSleep package and its Python modules in the PFTSleep/ directory, which can then be imported and used in other Python projects.

To add new functionality, create a new notebook or edit an existing one in the nbs/ directory, following the instructions in the nbdev documentation. Then run nbdev_prepare to regenerate the PFTSleep package with the new functionality.

Directory Structure:

  • nbs/: Contains the source notebooks that generate the Python package
  • jobs/: Contains processing and training scripts
    • apples/: Processing scripts for the apples dataset
    • mesa/: Processing scripts for the mesa dataset
    • mros/: Processing scripts for the mros dataset
    • shhs/: Processing scripts for the shhs dataset
    • wsc/: Processing scripts for the wsc dataset
    • model_training/:
      • train_transformer.py: Trains the initial foundational transformer model
      • train_classifier.py: Trains the probing head for sleep stage classification

Each dataset directory contains scripts to:

  • Create hypnogram CSVs from annotations
  • Build zarr files from EDF files
  • Process and standardize the data for model training

Model Training Pipeline

  1. Foundation Model Training (jobs/model_training/train_transformer.py)
    • Trains the base transformer model on sleep data zarr files
    • Creates general purpose representations of sleep signals
  2. Probe Training (jobs/model_training/train_classifier.py)
    • Trains a classification head on top of the foundation model

Technical Details

  • We trained the foundational model on two H100 80 GB GPUs using PyTorch Lightning.
  • We monitored training using the Weights and Biases platform.
  • We performed hyperparameter optimization using Optuna.

Citation

If you use PFTSleep in your research, please cite:

@ARTICLE{Fox2025-zc,
  title     = "A foundational transformer leveraging full night, multichannel
               sleep study data accurately classifies sleep stages",
  author    = "Fox, Benjamin and Jiang, Joy and Wickramaratne, Sajila and
               Kovatch, Patricia and Suarez-Farinas, Mayte and Shah, Neomi A
               and Parekh, Ankit and Nadkarni, Girish N",
  journal   = "Sleep",
  publisher = "Oxford University Press (OUP)",
  month     =  mar,
  year      =  2025,
  language  = "en"
}