# Multimodal Token Fusion for Vision Transformers

By Yikai Wang, Xinghao Chen, Lele Cao, Wenbing Huang, Fuchun Sun, Yunhe Wang.

[**[Paper]**](https://arxiv.org/pdf/2204.08721.pdf)

This repository is a PyTorch implementation of "Multimodal Token Fusion for Vision Transformers" (CVPR 2022).

<div align="center">
  <img src="./figs/framework.png" width="960">
</div>

Homogeneous predictions:
<div align="center">
  <img src="./figs/homogeneous.png" width="720">
</div>

Heterogeneous predictions:
<div align="center">
  <img src="./figs/heterogeneous.png" width="720">
</div>

## Datasets

For the semantic segmentation task on NYUDv2 ([official dataset](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)), we provide a download link [here](https://drive.google.com/drive/folders/1mXmOXVsd5l9-gYHk92Wpn6AcKAbE0m3X?usp=sharing). The provided dataset was originally preprocessed in this [repository](https://github.com/DrSleep/light-weight-refinenet), and we add depth data to it.

For the image-to-image translation task, we use the sample dataset of [Taskonomy](http://taskonomy.stanford.edu/), which can be downloaded [here](https://github.com/alexsax/taskonomy-sample-model-1.git).

Please modify the data paths in the code at the places marked with the comment 'Modify data path'.
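
For example, assuming those marker comments appear verbatim in the code, the lines to edit can be listed with a quick search:
```
# hypothetical search for the 'Modify data path' markers mentioned above
grep -rn "Modify data path" semantic_segmentation image2image_translation
```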

## Dependencies
```
python==3.6
pytorch==1.7.1
torchvision==0.8.2
numpy==1.19.2
```
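
A minimal environment sketch matching these versions might look as follows; conda and the environment name `tokenfusion` are our assumptions, and the `cudatoolkit` version should be adjusted to your CUDA setup:
```
# assumed setup; adjust cudatoolkit to your CUDA version
conda create -n tokenfusion python=3.6 -y
conda activate tokenfusion
conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch -y
pip install numpy==1.19.2
```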

## Semantic Segmentation

First,
```
cd semantic_segmentation
```

Download a [SegFormer](https://github.com/NVlabs/SegFormer) model pretrained on ImageNet from [weights](https://drive.google.com/drive/folders/1b7bwrInTW4VLEm27YawHOAMSMikga2Ia), e.g., `mit_b3.pth`, and move it to the folder `pretrained`.
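
For example, assuming `mit_b3.pth` was downloaded to the current directory:
```
# create the expected folder and move the downloaded weights into it
mkdir -p pretrained
mv mit_b3.pth pretrained/
```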

Training script for segmentation with RGB and depth input:
```
python main.py --backbone mit_b3 -c exp_name --lamda 1e-6 --gpu 0 1 2
```

Evaluation script:
```
python main.py --gpu 0 --resume path_to_pth --evaluate  # optionally use --save-img to visualize results
```

Checkpoint models, training logs, mask ratios, and the **single-scale** performance on NYUDv2 are provided below:

| Method | Backbone | Pixel Acc. (%) | Mean Acc. (%) | Mean IoU (%) | Download |
|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
| [CEN](https://github.com/yikaiw/CEN) | ResNet101 | 76.2 | 62.8 | 51.1 | [Google Drive](https://drive.google.com/drive/folders/1wim_cBG-HW0bdipwA1UbnGeDwjldPIwV?usp=sharing) |
| [CEN](https://github.com/yikaiw/CEN) | ResNet152 | 77.0 | 64.4 | 51.6 | [Google Drive](https://drive.google.com/drive/folders/1DGF6vHLDgBgLrdUNJOLYdoXCuEKbIuRs?usp=sharing) |
| Ours | SegFormer-B3 | 78.7 | 67.5 | 54.8 | [Google Drive](https://drive.google.com/drive/folders/14fi8aABFYqGF7LYKHkiJazHA58OBW1AW?usp=sharing) |

A MindSpore implementation is available at https://gitee.com/mindspore/models/tree/master/research/cv/TokenFusion.

## Image-to-Image Translation

First,
```
cd image2image_translation
```
Training script, from Shade and Texture to RGB:
```
python main.py --gpu 0 -c exp_name
```
This script will auto-evaluate on the validation dataset every 5 training epochs.

Predicted images will be automatically saved during training, in the following folder structure:

```
code_root/ckpt/exp_name/results
├── input0   # 1st modality input
├── input1   # 2nd modality input
├── fake0    # 1st branch output
├── fake1    # 2nd branch output
├── fake2    # ensemble output
├── best     # current best output
│   ├── fake0
│   ├── fake1
│   └── fake2
└── real     # ground truth output
```

Checkpoint models:

| Method | Task | FID | KID | Download |
|:-----------:|:-----------:|:-----------:|:-----------:|:-----------:|
| [CEN](https://github.com/yikaiw/CEN) | Texture+Shade->RGB | 62.6 | 1.65 | - |
| Ours | Texture+Shade->RGB | 45.5 | 1.00 | [Google Drive](https://drive.google.com/drive/folders/1vkcDv5bHKXZKxCg4dC7R56ts6nLLt6lh?usp=sharing) |

## 3D Object Detection (under construction)

Data preparation, environments, and training scripts follow [Group-Free](https://github.com/zeliu98/Group-Free-3D) and [ImVoteNet](https://github.com/facebookresearch/imvotenet).

E.g.,
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --master_port 2229 --nproc_per_node 4 train_dist.py --max_epoch 600 --val_freq 25 --save_freq 25 --lr_decay_epochs 420 480 540 --num_point 20000 --num_decoder_layers 6 --size_cls_agnostic --size_delta 0.0625 --heading_delta 0.04 --center_delta 0.1111111111111 --weight_decay 0.00000001 --query_points_generator_loss_coef 0.2 --obj_loss_coef 0.4 --dataset sunrgbd --data_root . --use_img --log_dir log/exp_name
```

## Citation

If you find our work useful for your research, please consider citing the following paper.
```
@inproceedings{wang2022tokenfusion,
  title={Multimodal Token Fusion for Vision Transformers},
  author={Wang, Yikai and Chen, Xinghao and Cao, Lele and Huang, Wenbing and Sun, Fuchun and Wang, Yunhe},
  booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2022}
}
```

# [ICML 2024] GeminiFusion for Multimodal Semantic Segmentation on the NYUDv2 & SUN RGBD Datasets

This is the official implementation of our paper "[GeminiFusion: Efficient Pixel-wise Multimodal Fusion for Vision Transformer](Link)".

Authors: Ding Jia, Jianyuan Guo, Kai Han, Han Wu, Chao Zhang, Chang Xu, Xinghao Chen

----------------------------

## Code List

We have applied our GeminiFusion to different tasks and datasets:

* GeminiFusion for Multimodal Semantic Segmentation
  * (This branch) [NYUDv2 & SUN RGBD datasets](https://github.com/JiaDingCN/GeminiFusion/tree/main)
  * [DeLiVER dataset](https://github.com/JiaDingCN/GeminiFusion/tree/DeLiVER)
* GeminiFusion for Multimodal 3D Object Detection
  * [KITTI dataset](https://github.com/JiaDingCN/GeminiFusion/tree/3d_object_detection_kitti)
----------------

## Installation

We build GeminiFusion on the TokenFusion codebase, which requires no additional installation steps. If you run into any problems with the framework, please refer to [the official TokenFusion readme](./README-TokenFusion.md).

Most of the `GeminiFusion`-related code is located in the following files:
* [models/mix_transformer](models/mix_transformer.py): implements the GeminiFusion module for MiT backbones.
* [models/swin_transformer](models/swin_transformer.py): implements the GeminiFusion module for Swin backbones.
* [mmcv_custom](mmcv_custom): loads checkpoints for Swin backbones.
* [main](main.py): enables the SUN RGBD dataset.
* [utils/datasets](utils/datasets.py): enables the SUN RGBD dataset.

We also removed config.py from the TokenFusion codebase since it is not used here.

## Getting Started

**NYUDv2 Dataset Preparation**

Please follow [the data preparation instructions for NYUDv2 in the TokenFusion readme](./README-TokenFusion.md#datasets). By default the data path is `/cache/datasets/nyudv2`; you may change it with `--train-dir <your data path>`, as shown in the sketch below.
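
For instance, a minimal sketch of training on a local copy of NYUDv2; the backbone and experiment name are taken from the training commands below, and the data path is a placeholder:
```shell
# placeholder path; point --train-dir at your local NYUDv2 copy
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py \
  --backbone mit_b3 --dataset nyudv2 --train-dir /path/to/nyudv2 -c nyudv2_mit_b3
```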

**SUN RGBD Dataset Preparation**

Please download the SUN RGBD dataset via the link in [DFormer](https://github.com/VCIP-RGBD/DFormer?tab=readme-ov-file#2--get-start). By default the data path is `/cache/datasets/sunrgbd_Dformer/SUNRGBD`; you may change it with `--train-dir <your data path>`.
**NYUDv2 Training** | ||
|
||
On the NYUDv2 dataset, we follow the TokenFusion's setting, using 3 GPUs to train the GeminiFusion. | ||
|
||
```shell | ||
# b3 | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 | ||
|
||
# b5 | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b5 --dataset nyudv2 -c nyudv2_mit_b5 --dpr 0.35 | ||
|
||
# swin_tiny | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_tiny --dataset nyudv2 -c nyudv2_swin_tiny --dpr 0.2 | ||
|
||
# swin_small | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_small --dataset nyudv2 -c nyudv2_swin_small | ||
|
||
# swin_large | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large --dataset nyudv2 -c nyudv2_swin_large | ||
|
||
# swin_large_window12 | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c nyudv2_swin_large_window12 --dpr 0.2 | ||
|
||
# swin-large-384+FineTune from SUN 300eps | ||
# swin-large-384.pth.tar should be downloaded by the link below or trained by yourself | ||
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone swin_large_window12 --dataset nyudv2 -c rerun_54.8_swin_large_window12_finetune_dpr0.15_100+200+100 \ | ||
--dpr 0.15 --num-epoch 100 200 100 --is_pretrain_finetune --resume ./swin-large-384.pth.tar | ||
``` | ||

**SUN RGBD Training**

On the SUN RGBD dataset, we use 4 GPUs to train GeminiFusion.
```shell
# b3
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b3 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b3

# b5
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone mit_b5 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_mit_b5 --weight_decay 0.05

# swin_tiny
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_tiny --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_tiny

# swin_large_window12
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --backbone swin_large_window12 --dataset sunrgbd --train-dir /cache/datasets/sunrgbd_Dformer/SUNRGBD -c sunrgbd_swin_large_window12
```

**Testing**

To evaluate a checkpoint, append `--eval --resume <checkpoint path>` to the training script.

For example, on the NYUDv2 dataset, the training script for GeminiFusion with the MiT-B3 backbone is:
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3
```

To evaluate the trained or downloaded checkpoint, the eval script is:
```shell
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 --eval --resume mit-b3.pth.tar
```

## Model Zoo

### NYUDv2 dataset

| Model | Backbone | mIoU | Download |
|:-------:|:--------:|:-------:|:-------------------:|
| GeminiFusion | MiT-B3 | 56.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2/mit-b3.pth.tar) |
| GeminiFusion | MiT-B5 | 57.7 | [model]() |
| GeminiFusion | swin_tiny | 52.2 | [model]() |
| GeminiFusion | swin-small | 55.0 | [model]() |
| GeminiFusion | swin-large-224 | 58.8 | [model]() |
| GeminiFusion | swin-large-384 | 60.2 | [model]() |
| GeminiFusion | swin-large-384 + FineTune from SUN 300 eps | 60.9 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2/finetune-swin-large-384.pth.tar) |
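
As a quick sanity check, the released MiT-B3 checkpoint above can be downloaded and evaluated with the eval script from the Testing section; this is a sketch, and the GPU ids should be adjusted to your machine:
```shell
# download the released checkpoint and evaluate it on NYUDv2
wget https://github.com/JiaDingCN/GeminiFusion/releases/download/NYUDv2/mit-b3.pth.tar
CUDA_VISIBLE_DEVICES=0,1,2 python -m torch.distributed.launch --nproc_per_node=3 --use_env main.py \
  --backbone mit_b3 --dataset nyudv2 -c nyudv2_mit_b3 --eval --resume mit-b3.pth.tar
```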

### SUN RGBD dataset

| Model | Backbone | mIoU | Download |
|:-------:|:--------:|:-------:|:-------------------:|
| GeminiFusion | MiT-B3 | 52.7 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN/mit-b3.pth.tar) |
| GeminiFusion | MiT-B5 | 53.3 | [model]() |
| GeminiFusion | swin_tiny | 50.2 | [model]() |
| GeminiFusion | swin-large-384 | 54.8 | [model](https://github.com/JiaDingCN/GeminiFusion/releases/download/SUN/swin-large-384.pth.tar) |
### Citation | ||
|
||
If you find this work useful for your research, please cite our paper: | ||
|
||
<!-- ``` | ||
@misc{rukhovich2023tr3d, | ||
doi = {10.48550/ARXIV.2302.02858}, | ||
url = {https://arxiv.org/abs/2302.02858}, | ||
author = {Rukhovich, Danila and Vorontsova, Anna and Konushin, Anton}, | ||
title = {TR3D: Towards Real-Time Indoor 3D Object Detection}, | ||
publisher = {arXiv}, | ||
year = {2023} | ||
} | ||
``` --> |