# PFTSleep

## Overview
PFTSleep is a Python package for sleep stage classification using a pre-trained foundational transformer. The repository is built using nbdev, which means the package is developed in Jupyter notebooks.
See the publication in SLEEP and original preprint for more details.
## Install

```sh
pip install PFTSleep
```
## Inference

To perform inference on unseen data (using EDF file paths as input), use the `pft_sleep_inference.sh`, `pft_sleep_inference.py`, and `pftsleep_inference_config.yaml` files.
Please create an account on Hugging Face and request access to the models here. You will also need to create a personal access token with read access. Read more here.
Then, download the encoder and sleep stage classifier models with:

```python
from pftsleep.inference import download_pftsleep_models

download_pftsleep_models(models_dir='', token=YOUR_HF_TOKEN)
```
You will also be prompted to download the files when running the inference script if they are not in `models_dir`.
After the models are downloaded, update the `pftsleep_inference_config.yaml` file with a path to your EDF file or EDF directory. You can pass either a single EDF file or a directory of EDFs. You can also use glob syntax to select EDFs within subdirectories (e.g. `/path/to/base/directory/**/`).
Unfortunately, due to differing naming conventions for signal channels, the dataloader will fail if you pass multiple EDFs whose channel names differ. In this case, it is recommended to rewrite the EDFs with a consistent channel-name format, or to perform inference on them one at a time.
If a specific channel is not available for a given EDF (or set of EDFs), pass the keyword "null" or "dummy" to that channel's name parameter in the yaml if you'd like to see how the model performs with that channel set to all zeros.
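As an illustrative sketch only (the key names below are assumptions, not the authoritative schema — check the `pftsleep_inference_config.yaml` shipped with the repository), a channel section with one missing channel might look like:

```yaml
# Hypothetical channel mapping — key names are illustrative assumptions.
channels:
  eeg: "C4-M1"       # referenced EEG
  eog_l: "E1-M2"     # left EOG
  emg: "Chin1-Chin2"
  ecg: "dummy"       # no ECG available: model sees all zeros for this channel
```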
The model expects referenced EEG, Left EOG, EMG, and ECG channels. If your channels are unreferenced, you may pass the corresponding reference channels. The model was trained with the following referenced channels (priority was given to the ones listed first):

- EEG: C4-M1 or C3-M2
- EMG: Chin1-Chin2 or Chin1-Chin3
- ECG: Augmented lead 2 (or ECG (LL) - ECG (RA))
- Left EOG: E1-M2
Check the `slumber.py` source code for NSRR-specific channels (under the `SHHS_CHANNELS`, `MROS_CHANNELS`, `WSC_CHANNELS`, etc. variables) if the above is confusing.
For the `device` parameter, use `"cpu"` (slowest), a CUDA GPU (e.g. `"cuda:0"`), or MPS (`"mps"` on macOS).
Prediction logits of shape [bs x 5 x 960] are output. Dimension 1 indexes the individual sleep stage logits, where index 0 is wake, 1 is N1, 2 is N2, 3 is N3, and 4 is REM. To retrieve probabilities, apply torch's softmax function along dimension 1 of the tensor. Note that the model expects an 8 hour input and returns 960 class predictions, one for each 30 second sleep epoch within the 8 hours. If the sleep study is longer than 8 hours, stages after the 8 hour mark will not be predicted. If the sleep study is shorter than 8 hours, predictions after the true length of the study should be ignored (even though the model still outputs 960 epochs). Please file an issue if this becomes a major problem and we can work on a solution. We are also working on models that will accept variable length input.
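As a minimal post-processing sketch (tensor contents and variable names here are illustrative; only the output shape and stage mapping come from above):

```python
import torch

# Illustrative logits: a batch of 2 studies, 5 stages, 960 thirty-second epochs.
logits = torch.randn(2, 5, 960)

# Probabilities over the 5 stages (dimension 1), then the predicted stage index.
probs = torch.softmax(logits, dim=1)   # [bs, 5, 960]; sums to 1 over dim 1
stages = probs.argmax(dim=1)           # [bs, 960]; values in {0..4}

# If a study was shorter than 8 hours, keep only the real epochs:
# e.g. a 6.5 hour study has 6.5 * 3600 / 30 = 780 epochs.
n_epochs = 780
stages_valid = stages[:, :n_epochs]

stage_names = {0: "Wake", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
print(stages_valid.shape)  # torch.Size([2, 780])
```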
To finally run the predictions on a single EDF file or directory of EDF files, give execute permissions to the shell script:

```sh
chmod +x pftsleep_inference.sh
```

Then run it, passing the config file as the main argument:

```sh
./pftsleep_inference.sh pftsleep_inference_config.yaml
```
Predictions will be output to a torch tensor file (`.pt`) at the location specified in the yaml.
Additionally, you can set the `save_hypjson` parameter in the yaml to `true`. This will apply softmax, select the max index, and save the predictions as a HYPJSON file (with the same filename as the EDF file, plus `_pftsleep.HYPJSON`). You can also use the `write_pred_to_hypjson` function for individual files. In this function, sleep stages are mapped to typical HYPJSON standards (for example, REM is mapped to the integer 5).
## Repository Structure and Usage
This is an nbdev repository, which means the package is developed in Jupyter notebooks located in the `nbs/` directory. Any modifications or additions to the `PFTSleep` package should be made by editing these notebooks.
To build the package, run `nbdev_prepare` in the terminal. This will generate the `PFTSleep` package and all of its Python modules in the `PFTSleep/` directory, which can then be imported and used in other Python projects.
To add new functionality, create a new notebook (or add to an existing one) in the `nbs/` directory, following the instructions in the nbdev documentation. Then run `nbdev_prepare` to regenerate the `PFTSleep` package with the new functionality.
Directory Structure:

- `nbs/`: Contains the source notebooks that generate the Python package
- `jobs/`: Contains processing and training scripts
  - `apples/`: Processing scripts for the APPLES dataset
  - `mesa/`: Processing scripts for the MESA dataset
  - `mros/`: Processing scripts for the MrOS dataset
  - `shhs/`: Processing scripts for the SHHS dataset
  - `wsc/`: Processing scripts for the WSC dataset
  - `model_training/`:
    - `train_transformer.py`: Trains the initial foundational transformer model
    - `train_classifier.py`: Trains the probing head for sleep stage classification
Each dataset directory contains scripts to:

- Create hypnogram CSVs from annotations
- Build zarr files from EDF files
- Process and standardize the data for model training
## Model Training Pipeline

1. Foundation Model Training (`jobs/model_training/train_transformer.py`)
   - Trains the base transformer model on sleep data zarr files
   - Creates general purpose representations of sleep signals
2. Probe Training (`jobs/model_training/train_classifier.py`)
   - Trains a classification head on top of the foundation model
## Technical Details
- We trained the foundational model on 2x H100 80 GB GPUs using PyTorch Lightning.
- We monitored training using the Weights and Biases platform.
- We performed hyperparameter optimization using Optuna.
## Citation
If you use PFTSleep in your research, please cite:
```bibtex
@ARTICLE{Fox2025-zc,
  title     = "A foundational transformer leveraging full night, multichannel
               sleep study data accurately classifies sleep stages",
  author    = "Fox, Benjamin and Jiang, Joy and Wickramaratne, Sajila and
               Kovatch, Patricia and Suarez-Farinas, Mayte and Shah, Neomi A
               and Parekh, Ankit and Nadkarni, Girish N",
  journal   = "Sleep",
  publisher = "Oxford University Press (OUP)",
  month     = mar,
  year      = 2025,
  language  = "en"
}
```