This repository contains the code accompanying the paper "Incorporating knowledge of plates in batch normalization improves generalization of deep learning for microscopy images" by Alex Lin and Alex Lu. Please find licensing details in the file `license.txt`.
- Our code is built on the PyTorch library.
- The majority of the code is packaged into modules in the `biomass` directory. Here, you will find specific implementations of models, data loaders, data augmentation transforms, etc. To simply run the code, you do not need to directly edit anything in this directory.
- The Python scripts in the root directory (e.g. `train_erm.py`, `train_simclr.py`, etc.) are the outer functions that run the experiments for the paper. To run a particular experiment, execute `python [script_name].py` (e.g. `python train_erm.py`) in the command line.
- Each script has an associated YAML file in the `configs` directory, so you do not need to edit the scripts themselves to vary experimental settings. Within each YAML config, there is a dictionary of hyperparameter configurations for each experiment. You can vary these to re-run our experiments with different settings. Our codebase uses Hydra for experiment management; for more information on how the Python scripts and YAML files interact, please consult the Hydra documentation. (A small illustrative config sketch follows this list.)
- When you run a script, our codebase will create three things:
  - In the `outputs` directory, it will create a directory named by the timestamp at which you ran the script. The YAML config file associated with the run will be copied into this directory so you can refer back to what hyperparameters you used at that time.
  - In the `runs` directory, it will dump a TensorBoard log charting the progress of model training. Within the TensorBoard log, we record metrics such as accuracy, loss, gradient dynamics, etc.
  - In the `checkpoints` directory, it will dump a checkpoint of the PyTorch model associated with the experiment.
- The `requirements.txt` file contains all of the Python packages and associated versions that we used. You can run `pip install -r requirements.txt` in the command line to automatically install the correct versions for each package.
- The `misc` directory contains miscellaneous files needed for running certain experiments.
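To make the config-driven workflow concrete, here is a minimal, purely illustrative sketch of the kind of `key: value` hyperparameters you will find in a config such as `configs/train_erm.yaml`. The key names come from the sections below; the values are placeholders rather than the shipped defaults, so always check the YAML file itself before editing.

```yaml
# Illustrative sketch only -- not the full configs/train_erm.yaml.
# Edit values in the YAML file, then re-run `python train_erm.py`;
# the config used for the run is copied into the outputs/ directory.
eval_batch_size: 75        # placeholder value; use whatever batch size you prefer
eval_plate_sampler: True   # BEN-related flag discussed in the supervised-learning notes below
```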
- To reproduce our supervised learning experiments, use the script `train_erm.py`. (Note that ERM stands for empirical risk minimization, i.e. another term for supervised learning.)
- The current associated config file `configs/train_erm.yaml` is set up to run supervised learning with BEN, our batch effects correction method.
- To run vanilla supervised learning (i.e. without BEN), make the following edits to `configs/train_erm.yaml`: delete `train_groupby`, delete `max_groups`, set `eval_plate_sampler: False`, set `eval_batch_size: 75` (or whatever batch size you prefer), and set `use_train_at_eval: False`. (A sketch of these edits is shown after this list.)
- The first time you run `train_erm.py`, note that our script should automatically download the RxRx1-Wilds dataset for you from the Wilds package.
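For convenience, here is a minimal sketch of what the no-BEN edits to `configs/train_erm.yaml` amount to. Only the keys mentioned above are shown; leave everything else in the shipped config untouched, and note that the surrounding structure of the real file may differ from this flat sketch.

```yaml
# Vanilla supervised learning (no BEN) -- edits to configs/train_erm.yaml.
# Only the keys discussed above are shown.

# train_groupby: ...    <- delete this key entirely
# max_groups: ...       <- delete this key entirely
eval_plate_sampler: False
eval_batch_size: 75         # or whatever batch size you prefer
use_train_at_eval: False
```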
- To reproduce our self-supervised learning experiments, use the script `train_simclr.py` for training the base model and `train_classifier.py` for fitting the linear classifier on the learned representations.
- Thus, in the `train_classifier.yaml` file, there is an argument `model_path: xxxxx` that needs to point to a saved checkpoint obtained from running `train_simclr.py`. Make sure to manually set the correct path in `train_classifier.yaml` after running `train_simclr.py`.
- In both `train_simclr.yaml` and `train_classifier.yaml`, there is an argument `img_dir` that needs to point to a directory of cropped cells for RxRx1-Wilds. A zipped file of this directory can be downloaded at this link (note it is about 4 GB in size): https://zenodo.org/record/7272553#.Y2KkNuzMJTZ
- The current config file `train_simclr.yaml` is set up to train the vanilla SimCLR algorithm (without BEN). To run SimCLR + BEN, simply change the argument `sampler: random` to `sampler: plate`. Then, to apply BEN while training the classifier, go to `train_classifier.yaml` and change the arguments `sampler: random` -> `sampler: plate` and `model_train_mode: False` -> `model_train_mode: True`. (A sketch of these keys appears after this list.)
- To use MinCLR (i.e. multiple instance contrastive learning, a new method that we developed) instead of SimCLR, simply go to `train_simclr.yaml` and change `mode: random_single` to `mode: random` (this will define positive anchors as random cells from the same image instead of random augmentations of the same single cell). To run MinCLR + BEN, follow the aforementioned instructions for SimCLR + BEN.
- To increase the number of positive anchors during training, go to `train_simclr.yaml` and change `num_img: 2` to any other value (e.g. `num_img: 5`). Make sure that this value also matches the argument `num_views: 2` (e.g. `num_views: 5`).
- To train representations from a cell-level supervised classifier (instead of a self-supervised learner), use the script `train_supervised_cell.py`. The current YAML file `train_supervised_cell.yaml` is designed for standard supervised learning (without BEN). To use BEN, simply change `sampler: random` to `sampler: plate`.
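To pull the settings above together, here is a minimal sketch of the keys discussed in this section as they appear in the two self-supervised configs. The `img_dir` and `model_path` values are hypothetical placeholders that you must replace with your own paths; the remaining values reflect the shipped (vanilla SimCLR, no-BEN) settings described above, and the comments note which changes enable MinCLR and BEN. Other keys in the configs are not shown.

```yaml
# --- train_simclr.yaml (relevant keys only) ---
img_dir: /path/to/rxrx1_cropped_cells   # placeholder: point to the unzipped cell-crop directory
sampler: random        # change to `plate` for SimCLR + BEN (or MinCLR + BEN)
mode: random_single    # change to `random` for MinCLR
num_img: 2             # number of positive anchors; keep equal to num_views
num_views: 2

# --- train_classifier.yaml (relevant keys only) ---
img_dir: /path/to/rxrx1_cropped_cells   # placeholder: same cell-crop directory
model_path: checkpoints/your_simclr_checkpoint.pt   # hypothetical path: set to your train_simclr.py checkpoint
sampler: random          # change to `plate` to apply BEN when fitting the classifier
model_train_mode: False  # change to True to apply BEN when fitting the classifier
```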
- To reproduce our transfer learning experiments, use the notebook `pybbbc.ipynb`.
- To obtain the dataset for BBBC021, use the following package: https://github.com/giacomodeodato/pybbbc and follow its instructions for "Data download" and "Dataset creation". Note that this can take several hours. Afterwards, the data should be dumped by default into a directory called `~/.cache/`.
- You also need to obtain two files that list the nuclei centers for cells in this dataset. You can download these files (called `supplement_Object.txt` and `supplement_Image.txt`) from the supplementary files of this paper: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3884769/ (under the title "Data S2").
- Since this is transfer learning, you should also have a path to the checkpoint of a model pre-trained on the RxRx1-Wilds (cell-level) dataset, obtained by running a self-supervised learning experiment (for example). This checkpoint will need to be loaded into the state dict of the PyTorch model (see the notebook).
- Towards the end of the notebook, we calculate both NSC and NSCB accuracy (see the paper for more details on these metrics).