PARTICLE

This repository contains the code for training and evaluating the part-contrastive learning method proposed in

PARTICLE: Part Discovery and Contrastive Learning for Fine-grained Recognition

Oindrila Saha, Subhransu Maji

Preparation

Create a conda environment from the provided specification file:

conda create --name <env name> --file spec-file.txt
conda activate <env name>

-> Create a folder named data_dir in the repository root

-> Download the PASCUB dataset from here into data_dir

-> Download the full CUB dataset (images and segmentations) from here into data_dir, then move all images into a single folder:

cd <path to cub data>/images/
mkdir -p ../images_extracted
for folder in *; do mv "$folder"/* ../images_extracted/; done

-> Download the OID dataset from here into data_dir

-> Download the pretrained checkpoints from here into the pretrained_models folder inside ssl_training
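If you prefer not to rely on the shell loop for the CUB step above, the same flattening can be sketched in Python. The helper name flatten_images is hypothetical and not part of this repository. It creates the destination folder itself and refuses to overwrite an existing file, which assumes image file names are unique across the CUB class folders; the duplicate check will flag it if they are not.

```python
from pathlib import Path
import shutil


def flatten_images(src_root: str, dst_root: str) -> int:
    """Hypothetical helper: move every file in src_root/<class folder>/ into dst_root/.

    Raises FileExistsError instead of silently overwriting a duplicate name.
    Returns the number of files moved.
    """
    src, dst = Path(src_root), Path(dst_root)
    dst.mkdir(parents=True, exist_ok=True)
    moved = 0
    for class_dir in sorted(p for p in src.iterdir() if p.is_dir()):
        for image in sorted(p for p in class_dir.iterdir() if p.is_file()):
            target = dst / image.name
            if target.exists():
                raise FileExistsError(f"duplicate file name: {image.name}")
            shutil.move(str(image), str(target))
            moved += 1
    return moved
```

Usage: flatten_images("<path to cub data>/images", "<path to cub data>/images_extracted").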

PARTICLE Training

cd ssl_training

ResNet Variation - DetCon init

First, generate the part clusters:

python generate_clusters_resnet.py --dataset <dataset type - "birds" or "aircrafts"> --save_dir <dir to save part cluster masks>
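The internals of generate_clusters_resnet.py are not described here. As a rough illustration of part discovery by clustering, the sketch below assumes you already have dense per-pixel features for a batch of images, clusters all pixel features jointly with k-means, and reshapes the assignments back into per-image part-label maps. The function names, the choice of plain k-means, and the farthest-point initialisation are assumptions, not the repository's implementation; for full-size data, a scalable implementation such as scikit-learn's MiniBatchKMeans would be more practical.

```python
import numpy as np


def kmeans(x: np.ndarray, k: int, iters: int = 20, seed: int = 0) -> np.ndarray:
    """Lloyd's k-means on the rows of x (n, d); returns one cluster id per row."""
    rng = np.random.default_rng(seed)
    # Farthest-point seeding: start from a random row, then repeatedly add
    # the row that is farthest from all centers chosen so far.
    centers = [x[rng.integers(len(x))]]
    for _ in range(1, k):
        d2 = np.min([((x - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(x[d2.argmax()])
    centers = np.stack(centers).astype(float)
    for _ in range(iters):
        # Squared Euclidean distance from every row to every center: (n, k).
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        for j in range(k):
            members = x[labels == j]
            if len(members):  # keep the old center if a cluster empties out
                centers[j] = members.mean(axis=0)
    return labels


def part_masks_from_features(feats: np.ndarray, k: int) -> np.ndarray:
    """Hypothetical helper: (N, H, W, D) dense features -> (N, H, W) part ids.

    Features are L2-normalised and clustered jointly across all images, so
    the same id refers to the same discovered "part" in every image.
    """
    n, h, w, d = feats.shape
    flat = feats.reshape(-1, d)
    flat = flat / (np.linalg.norm(flat, axis=1, keepdims=True) + 1e-8)
    return kmeans(flat, k).reshape(n, h, w)
```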

Next, train ResNet PARTICLE:

python train_particle_resnet.py --dataset <dataset type - "birds" or "aircrafts"> --seg_dir <dir of masks generated in previous step> --save_path <path to save models>
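The ResNet variant is initialised from DetCon, whose objective contrasts features pooled inside masks. The sketch below illustrates that style of part-level contrastive loss under stated assumptions: features of two augmented views are average-pooled inside each part mask, and an InfoNCE loss asks part j in one view to match part j in the other view rather than the other parts. The function names and the temperature tau=0.1 are assumptions; this is not the loss implemented in train_particle_resnet.py.

```python
import numpy as np


def mask_pool(feats: np.ndarray, masks: np.ndarray, k: int):
    """Average (H, W, D) features inside each of k part masks given as (H, W) ids.

    Returns (k, D) L2-normalised part embeddings and a (k,) bool array
    marking which parts are present in this view.
    """
    out = np.zeros((k, feats.shape[-1]))
    present = np.zeros(k, dtype=bool)
    for j in range(k):
        sel = masks == j
        if sel.any():
            v = feats[sel].mean(axis=0)
            out[j] = v / (np.linalg.norm(v) + 1e-8)
            present[j] = True
    return out, present


def part_contrastive_loss(f1, m1, f2, m2, k: int, tau: float = 0.1) -> float:
    """InfoNCE over parts; only parts visible in both views are used."""
    z1, p1 = mask_pool(f1, m1, k)
    z2, p2 = mask_pool(f2, m2, k)
    keep = p1 & p2
    if not keep.any():
        raise ValueError("no part is visible in both views")
    z1, z2 = z1[keep], z2[keep]
    logits = z1 @ z2.T / tau  # (P, P) cosine similarities / temperature
    logits -= logits.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # Positives sit on the diagonal: the same part id in both views.
    return float(-np.mean(np.diag(log_prob)))
```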

ViT Variation - DINO init

First, generate the part clusters:

python generate_clusters_vit.py --dataset <dataset type - "birds" or "aircrafts"> --save_dir <dir to save part cluster masks>
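For the ViT variant, dense features come as patch tokens rather than a feature map. A minimal sketch of the bookkeeping this involves, assuming a single leading [CLS] token (as in DINO ViTs) and a patch size of 16: drop the [CLS] token, reshape the remaining tokens into an (h, w, D) grid that can be clustered per patch, and upsample the resulting patch-level label map to pixel resolution with nearest-neighbour repetition. Both helper names are hypothetical.

```python
import numpy as np


def tokens_to_grid(tokens: np.ndarray, img_h: int, img_w: int, patch: int = 16) -> np.ndarray:
    """(1 + h*w, D) ViT output with a leading [CLS] token -> (h, w, D) grid."""
    h, w = img_h // patch, img_w // patch
    patch_tokens = tokens[1:]  # drop [CLS]
    if patch_tokens.shape[0] != h * w:
        raise ValueError("token count does not match image size and patch size")
    # Patch tokens are laid out row by row, so a plain reshape recovers the grid.
    return patch_tokens.reshape(h, w, -1)


def upsample_labels(labels: np.ndarray, patch: int = 16) -> np.ndarray:
    """Nearest-neighbour upsampling of an (h, w) patch label map to pixels."""
    return np.repeat(np.repeat(labels, patch, axis=0), patch, axis=1)
```

For a 224x224 input and patch size 16, the grid is 14x14, and upsample_labels turns a 14x14 part map back into a 224x224 mask.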

Next, train ViT PARTICLE:

python -m torch.distributed.launch --nproc_per_node=8 train_particle_vit.py --dataset <dataset type - "birds" or "aircrafts"> --seg_dir <dir of masks generated in previous step> --output_dir <path to save models>
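With --nproc_per_node=8, the launcher starts eight copies of train_particle_vit.py on one machine (typically one per GPU) and tells each copy its position through environment variables such as RANK and WORLD_SIZE. This matters when changing the GPU count: in DINO-style training code the batch size is given per GPU, and DINO scales the learning rate linearly as base_lr * total_batch / 256. Whether train_particle_vit.py follows exactly this convention is an assumption; the sketch below only illustrates the arithmetic, and both helper names are hypothetical.

```python
import os


def distributed_context() -> tuple:
    """Read (rank, world_size) exported by the distributed launcher.

    Falls back to a single process when the variables are absent,
    e.g. when the script is started with plain `python`.
    """
    rank = int(os.environ.get("RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    return rank, world_size


def scaled_lr(base_lr: float, batch_size_per_gpu: int, world_size: int) -> float:
    """Linear scaling rule (assumed, DINO-style): base_lr * total_batch / 256."""
    return base_lr * batch_size_per_gpu * world_size / 256.0
```

Under this assumption, running on 4 GPUs instead of 8 halves the effective batch size and therefore the scaled learning rate.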

PARTICLE Evaluation

cd evaluation

Linear Evaluation

Evaluate classification performance using the checkpoints obtained by training PARTICLE:

python test_linear.py --arch <architecture - "resnet50" or "vit_small"> --dataset <dataset type - "birds" or "aircrafts"> --pretrained_weights <trained particle checkpoint>
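Linear evaluation keeps the pretrained backbone frozen and trains only a linear classifier on top of its features, reporting top-1 accuracy. As a minimal sketch of that protocol, assuming features have already been extracted into a NumPy array, a multinomial logistic regression trained by full-batch gradient descent is enough. The function names, learning rate, and epoch count are illustrative assumptions; test_linear.py may train its classifier differently (for example with SGD on augmented images).

```python
import numpy as np


def train_linear_probe(x, y, num_classes, lr=0.5, epochs=200):
    """Multinomial logistic regression on frozen features x (n, d), labels y (n,)."""
    n, d = x.shape
    w = np.zeros((d, num_classes))
    b = np.zeros(num_classes)
    onehot = np.eye(num_classes)[y]
    for _ in range(epochs):
        logits = x @ w + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        grad = (probs - onehot) / n  # gradient of mean cross-entropy w.r.t. logits
        w -= lr * x.T @ grad
        b -= lr * grad.sum(axis=0)
    return w, b


def top1_accuracy(x, y, w, b) -> float:
    """Percentage of samples whose highest-scoring class is the true label."""
    return float(((x @ w + b).argmax(axis=1) == y).mean() * 100)
```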

Running this code on the PARTICLE checkpoints should give the following top-1 accuracy (%):

| Base SSL | DetCon | DINO |
|---|---|---|
| Birds | 40.88 | 84.15 |
| Aircrafts | 43.99 | 73.59 |

Few-Shot Part Segmentation

Train on the downstream part segmentation task using the checkpoints obtained by training PARTICLE:

python train_fcn.py --arch <architecture - "res50" or "dino"> --dataset <dataset type - "birds" or "aircrafts"> --ckpt <trained particle checkpoint> --save_path <path to save FCN segmentation models>

Compute the cross-validation mIoU:

python test_miou.py --arch <architecture - "res50" or "dino"> --dataset <dataset type - "birds" or "aircrafts"> --ckpt_dir <folder of segmentation checkpoints from previous step>
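For reference, mIoU is the intersection-over-union of each part class, IoU_c = TP_c / (TP_c + FP_c + FN_c), averaged over classes. A minimal sketch, assuming integer label maps, an ignore label of 255, and a cross-validation score equal to the mean of per-fold mIoU values; these conventions, including whether background counts as a class, are assumptions rather than a description of test_miou.py.

```python
import numpy as np


def confusion_matrix(pred, gt, num_classes, ignore_index=255):
    """Accumulate a (C, C) matrix: rows = ground truth, columns = prediction."""
    valid = gt != ignore_index
    idx = num_classes * gt[valid].astype(int) + pred[valid].astype(int)
    return np.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)


def mean_iou(conf):
    """IoU_c = TP / (TP + FP + FN), averaged over classes that occur, in percent."""
    tp = np.diag(conf).astype(float)
    denom = conf.sum(axis=0) + conf.sum(axis=1) - tp
    iou = tp[denom > 0] / denom[denom > 0]
    return float(iou.mean() * 100)


def cross_val_miou(fold_confs):
    """Average of per-fold mIoU (assumed aggregation across folds)."""
    return float(np.mean([mean_iou(c) for c in fold_confs]))
```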

Running this code on the PARTICLE checkpoints should give the following mIoU (%):

| Base SSL | DetCon | DINO |
|---|---|---|
| Birds | 49.23 | 50.59 |
| Aircrafts | 58.95 | 61.68 |