SuPreM

For Anatomical Segmentation on (a subset of) TotalSegmentator



0. Create a virtual environment (optional)
[Optional] Install Anaconda on Linux
wget https://repo.anaconda.com/archive/Anaconda3-2024.06-1-Linux-x86_64.sh
bash Anaconda3-2024.06-1-Linux-x86_64.sh -b -p ./anaconda3
./anaconda3/bin/conda init
source ~/.bashrc
conda create -n suprem python=3.8 -y
source activate suprem
1. Clone the GitHub repository
git clone https://github.com/MrGiovanni/SuPreM
cd SuPreM/
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install "monai[all]==0.9.0"
pip install pip==23.3.1
pip install -r requirements.txt
2. Download the pre-trained checkpoints (Swin UNETR, SegResNet, U-Net)
cd target_applications/totalsegmentator/pretrained_weights/
wget https://huggingface.co/MrGiovanni/SuPreM/resolve/main/supervised_suprem_swinunetr_2100.pth
wget https://huggingface.co/MrGiovanni/SuPreM/resolve/main/supervised_suprem_segresnet_2100.pth
wget https://huggingface.co/MrGiovanni/SuPreM/resolve/main/supervised_suprem_unet_2100.pth
cd ../../../
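Before moving on, it can help to confirm that all three downloads completed. The sketch below is a hypothetical helper, not part of the SuPreM repository; it only assumes the file names used in the wget commands above and that it is run from the repository root.

```shell
# Hypothetical helper (not part of the SuPreM repository): report which of
# the three checkpoints from step 2 are present and non-empty.
check_checkpoints() {
    local dir="${1:-target_applications/totalsegmentator/pretrained_weights}"
    local missing=0 arch f
    for arch in swinunetr segresnet unet; do
        f="$dir/supervised_suprem_${arch}_2100.pth"
        if [ -s "$f" ]; then
            echo "ok       $f"
        else
            echo "missing  $f"
            missing=$((missing + 1))
        fi
    done
    # The return status is the number of missing files (0 means all present).
    return "$missing"
}

# Run from the repository root.
check_checkpoints || echo "Some checkpoints are missing; re-run the wget commands above."
```

A zero-byte file counts as missing here, which catches interrupted downloads that left an empty file behind.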
3. Download the TotalSegmentator dataset

Download the dataset from Zenodo (1,228 subjects) and save it to /path/to/your/data/TotalSegmentator
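A quick count of subject folders can confirm that the archive was extracted where the later steps expect it. This is a rough sketch under one stated assumption: each subject is a sub-directory of the dataset root containing a file named ct.nii.gz. The helper name is hypothetical, and the file name should be adjusted if your copy of the dataset is laid out differently.

```shell
# Hypothetical helper (not part of the SuPreM repository).
# ASSUMPTION: every subject is a sub-directory of the dataset root that
# contains a ct.nii.gz file; change the file name if your copy differs.
count_subjects() {
    local root="$1" n=0 d
    for d in "$root"/*/; do
        if [ -f "${d}ct.nii.gz" ]; then
            n=$((n + 1))
        fi
    done
    echo "$n"
}

datapath=/path/to/your/data/TotalSegmentator   # change to your dataset root
echo "subjects found: $(count_subjects "$datapath")"
```

If the count is far below the 1,228 subjects mentioned above, the path probably points one directory level too high or too low.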

4. Fine-tune the pre-trained Swin UNETR on TotalSegmentator
# Single GPU

cd target_applications/totalsegmentator/
RANDOM_PORT=$((RANDOM % 64512 + 1024))
datapath=/scratch/zzhou82/data/Totalsegmentator_dataset/Totalsegmentator_dataset/ # change to /path/to/your/data/TotalSegmentator
arch=swinunetr # supported values: swinunetr, unet, segresnet
suprem_path=pretrained_weights/supervised_suprem_swinunetr_2100.pth
target_task=vertebrae
num_target_class=25
num_target_annotation=64

python -W ignore -m torch.distributed.launch --nproc_per_node=1 --master_port=$RANDOM_PORT train.py --dist  --model_backbone $arch --log_name efficiency.$arch.$target_task.number$num_target_annotation --map_type $target_task --num_class $num_target_class --dataset_path $datapath --num_workers 8 --batch_size 2 --pretrain $suprem_path --percent $num_target_annotation
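A note on the port arithmetic: in Bash, $RANDOM is an integer between 0 and 32767, so taking it modulo 64512 leaves it unchanged and the command above picks a port between 1024 and 33791. That is fine for a single-node run. The sketch below works through the bounds and shows one optional variant that covers 1024 to 65535; the variant is an assumption for illustration, not something taken from the repository.

```shell
# $RANDOM lies in [0, 32767] in Bash, so "% 64512" never changes its value.
min_port=$((0 % 64512 + 1024))
max_port=$((32767 % 64512 + 1024))
echo "RANDOM_PORT range: $min_port-$max_port"   # RANDOM_PORT range: 1024-33791

# Optional variant (an assumption, not from the repository): combine two
# draws into a 30-bit number so the modulo spans the full 1024-65535 range.
RANDOM_PORT=$(( (RANDOM * 32768 + RANDOM) % 64512 + 1024 ))
echo "picked port: $RANDOM_PORT"
```

Neither form checks whether the port is already in use; if the launcher reports an address conflict, re-run to draw a different port.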
5. Evaluate the performance per class
# Single GPU

cd target_applications/totalsegmentator/
RANDOM_PORT=$((RANDOM % 64512 + 1024))
datapath=/scratch/zzhou82/data/Totalsegmentator_dataset/Totalsegmentator_dataset/ # change to /path/to/your/data/TotalSegmentator
arch=swinunetr # must match the backbone fine-tuned in step 4
target_task=vertebrae
num_target_class=25
num_target_annotation=64
checkpoint_path=out/efficiency.$arch.$target_task.number$num_target_annotation/best_model.pth # set after the variables it expands

python -W ignore -m torch.distributed.launch --nproc_per_node=1 --master_port=$RANDOM_PORT test.py --dist  --model_backbone $arch --log_name efficiency.$arch.$target_task.number$num_target_annotation --map_type $target_task --num_class $num_target_class --dataset_path $datapath --num_workers 8 --batch_size 2 --pretrain $checkpoint_path --train_type efficiency
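The shell expands variables at the moment an assignment runs, so checkpoint_path only resolves correctly when arch, target_task, and num_target_annotation already hold their values. The sketch below illustrates the difference and adds a hypothetical existence check; the log_name variable is introduced here for illustration only, and the values mirror the ones used in steps 4 and 5.

```shell
# Simulate a fresh shell in which nothing has been set yet.
arch=""; target_task=""; num_target_annotation=""
early_path=out/efficiency.$arch.$target_task.number$num_target_annotation/best_model.pth
echo "$early_path"        # out/efficiency...number/best_model.pth

# Set the variables first, then build the path from them.
arch=swinunetr
target_task=vertebrae
num_target_annotation=64
log_name=efficiency.$arch.$target_task.number$num_target_annotation   # illustrative name
checkpoint_path=out/$log_name/best_model.pth
echo "$checkpoint_path"   # out/efficiency.swinunetr.vertebrae.number64/best_model.pth

# Fail early with a readable message rather than an error inside test.py.
if [ ! -f "$checkpoint_path" ]; then
    echo "checkpoint not found: $checkpoint_path (run step 4 first)"
fi
```

Running steps 4 and 5 in the same shell session hides the problem, because the variables from step 4 are still set; it shows up when evaluation is run later in a new terminal.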