PyTorch Lightning utilities that make it easier to train and evaluate deep models for the Neural Latents Benchmark. Key components include a preprocessing script, a `LightningDataModule`, an example `LightningModule`, logging callbacks, and training scripts. The provided callbacks report competition evaluation metrics, inferred firing rate rasters, and latent state trajectories.
If you use or reference `nlb-lightning` in published work, please cite:
Andrew R. Sedler. (2022). PyTorch Lightning utilities for the Neural Latents Benchmark (v0.1.0). Zenodo. https://doi.org/10.5281/zenodo.6811601
Clone the entire `nlb-lightning` repo, including submodules. Then, create and activate a `conda` environment and install the `nlb_tools` and `nlb_lightning` packages.
git clone --recurse-submodules https://github.com/arsedler9/nlb-lightning.git
cd nlb-lightning
conda create --name nlb-lightning python=3.9
conda activate nlb-lightning
pip install -e nlb_tools
pip install -e .
First, install the `dandi` CLI (`pip install dandi`) and download each of the NLB datasets into the same folder (e.g. `dandi download DANDI:000128/0.220113.0400` for `mc_maze`). We will refer to this folder as `DATA_HOME`. Edit the `.env` file in the `nlb-lightning` repo to set this path appropriately. Download commands can be found on the DANDI Archive page for each dataset (links on the NLB Datasets page).
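As an illustration of how a `DATA_HOME` entry in `.env` might be consumed, the sketch below parses `KEY=VALUE` lines with the standard library only. The path is a hypothetical placeholder, `parse_env` is a helper invented for this example, and the repository itself may load `.env` differently. The per-dataset folder name is also an assumption about how `dandi download` lays out files.

```python
from pathlib import Path


def parse_env(text: str) -> dict:
    """Parse simple KEY=VALUE lines, skipping blank lines and comments."""
    settings = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        settings[key.strip()] = value.strip().strip('"').strip("'")
    return settings


# Hypothetical .env contents; replace the path with your own download folder.
example_env = """
# folder that holds every downloaded NLB dataset
DATA_HOME=/path/to/nlb/data
"""

settings = parse_env(example_env)
data_home = Path(settings["DATA_HOME"])

# Assumption: each dandiset ends up in a subfolder named after its ID.
mc_maze_dir = data_home / "000128"
print(mc_maze_dir)
```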
We define a `main` function in the `scripts/preprocess.py` module that preprocesses and saves data for training, using the functions in `nlb_tools.make_tensors`. This module can be run as a script to generate all 28 splits used for NLB. The preprocessed data will be stored in the `PREP_HOME` folder (edit `.env` accordingly), with a unique folder for each split (`{dataset_name}-{bin_width:02}ms-{phase}`). Within each split folder, the filenames used for each type of data are specified by `TRAIN_INPUT_FILE`, `EVAL_INPUT_FILE`, and `EVAL_TARGET_FILE` in `.env`. The `NLBDataModule` will load data directly from these files.
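To make the folder naming concrete, the following sketch enumerates split folder names with the pattern quoted above. The dataset names, bin widths, and phases listed here are assumptions taken from the public NLB description (seven datasets, 5 ms and 20 ms bins, `val` and `test` phases), not read from `scripts/preprocess.py`, and `split_folder` is a hypothetical helper.

```python
from itertools import product

# Assumed from the public NLB description, not from the repo's scripts.
DATASET_NAMES = [
    "mc_maze",
    "mc_rtt",
    "area2_bump",
    "dmfc_rsg",
    "mc_maze_large",
    "mc_maze_medium",
    "mc_maze_small",
]
BIN_WIDTHS_MS = [5, 20]
PHASES = ["val", "test"]


def split_folder(dataset_name: str, bin_width: int, phase: str) -> str:
    """Format a split folder name; `:02` zero-pads the bin width to two digits."""
    return f"{dataset_name}-{bin_width:02}ms-{phase}"


split_folders = [
    split_folder(name, width, phase)
    for name, width, phase in product(DATASET_NAMES, BIN_WIDTHS_MS, PHASES)
]

print(len(split_folders))  # 28
print(split_folders[0])    # mc_maze-05ms-val
```

Seven datasets, two bin widths, and two phases give the 28 splits mentioned above.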
New models should be implemented as subclasses of `pytorch_lightning.LightningModule`, as demonstrated in the `SequentialAutoencoder`. See the `pytorch_lightning` docs for more details. There are two requirements for compatibility with the datamodule and the evaluation and plotting callbacks. If you are not using the callbacks, feel free to ignore these requirements.
- On `val`-phase data, the batch tuple must contain `input_data` and `recon_data` in the first and second positions and `behavior` in the last position. Other tensors may be included between `recon_data` and `behavior`. On `test`-phase data, the `validation_step` batch must contain `input_data` in the first position.
- The `forward` method must take `input_data` as its first argument, along with any extra dataset tensors (those between `recon_data` and `behavior`, as mentioned in the first requirement). The first and second outputs must be a tensor of rate estimates (including heldin/heldout and observed/forward) and a tensor of latent states, respectively.
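The two requirements above can be sketched with plain Python tuples. This is only an illustration of the positional conventions: the lists stand in for tensors, and `unpack_val_batch`, `unpack_test_batch`, and `toy_forward` are hypothetical names that do not exist in `nlb_lightning`.

```python
def unpack_val_batch(batch):
    """Split a val-phase batch: input first, recon second, behavior last."""
    input_data, recon_data, *extra, behavior = batch
    return input_data, recon_data, extra, behavior


def unpack_test_batch(batch):
    """Split a test-phase batch: only the first position is specified."""
    input_data, *extra = batch
    return input_data, extra


def toy_forward(input_data, *extra):
    """Stand-in for `forward`: rate estimates first, latent states second."""
    rates = [2.0 * x for x in input_data]    # placeholder rate estimates
    latents = [0.5 * x for x in input_data]  # placeholder latent states
    return rates, latents


# Placeholder lists stand in for tensors; one extra tensor sits in the middle.
val_batch = ([1.0, 2.0], [1.0, 2.0, 3.0], ["extra"], [0.1, 0.2])
input_data, recon_data, extra, behavior = unpack_val_batch(val_batch)

# Extra dataset tensors are passed to forward after input_data.
rates, latents = toy_forward(input_data, *extra)
```

In a real model, the same unpacking would live inside the step methods of the `LightningModule` subclass, and `forward` would return tensors rather than lists.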
We provide an example training script for the `SequentialAutoencoder` at `scripts/train_sae.py`. Your script may be different, depending on the hyperparameters of your model and the specific callbacks or other `pytorch_lightning` features that you choose to use. Training logs and checkpoints will be stored in the `RUNS_HOME` directory that you specify in `.env`. Inside `RUNS_HOME`, each run will be stored at `{run_tag}/{data_tag}`, where `data_tag` follows the same pattern as the preprocessed split folders (`{dataset_name}-{bin_width:02}ms-{phase}`). After training, the output of each model in a `run_tag` directory is written to a shared `submission-{phase}.h5` file, which can be uploaded directly to EvalAI.
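The directory layout described above can be sketched with `pathlib`. The helper names and the example values are hypothetical, and placing the submission file directly inside the `run_tag` directory is an assumption based on the description rather than on the training scripts.

```python
from pathlib import Path


def data_tag(dataset_name: str, bin_width: int, phase: str) -> str:
    """Build the data tag with the split folder pattern."""
    return f"{dataset_name}-{bin_width:02}ms-{phase}"


def run_dir(runs_home: str, run_tag: str, tag: str) -> Path:
    """Directory for one run: RUNS_HOME/{run_tag}/{data_tag}."""
    return Path(runs_home) / run_tag / tag


def submission_path(runs_home: str, run_tag: str, phase: str) -> Path:
    """Shared submission file for every model trained under one run_tag."""
    # Assumption: the file sits directly inside the run_tag directory.
    return Path(runs_home) / run_tag / f"submission-{phase}.h5"


# Hypothetical values for illustration only.
runs_home = "/path/to/runs"
tag = data_tag("mc_maze", 5, "val")

print(run_dir(runs_home, "my_run", tag))
print(submission_path(runs_home, "my_run", "val"))
```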
For those with the resources to train multiple models in parallel, we provide a second training script (`scripts/train_sae_all.py`) which uses `ray.tune` to train a model on each of the 28 NLB splits and save all rate estimates to the relevant submission file.