Yiyuan Zhang, Kaixiong Gong, Xiaohan Ding, Kaipeng Zhang2,†, Fangrui Lv5, Kurt Keutzer3, Xiangyu Yue1,†
2 OpenGVLab, Shanghai AI Laboratory 3 UC Berkeley 4 Tencent AI Lab 5 Tsinghua University
* Equal Contribution † Corresponding Authors
- 🚀 UniDG is an effective Test-Time Adaptation scheme. It brings an average improvement of +5.0% accuracy to existing DG methods on DomainBed benchmarks, including the PACS, VLCS, OfficeHome, and TerraInc datasets.
- 🚀 UniDG is architecture-agnostic. Unified with 10+ visual backbones, including CNN-, MLP-, and Transformer-based models, UniDG delivers a consistent average performance gain of +5.4% on domain generalization.
- 🏆 Achieves 79.6 mAcc on the PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet datasets.
We propose UniDG, a novel and Unified framework for Domain Generalization that significantly enhances the out-of-distribution generalization performance of foundation models regardless of their architectures. The core idea of UniDG is to fine-tune models during the inference stage, which saves the cost of iterative training. Specifically, we encourage models to learn the distribution of the test data in an unsupervised manner and impose a penalty on the update step of the model parameters. This penalty term effectively mitigates catastrophic forgetting, as we would like to maximally preserve the valuable knowledge in the original model. Empirically, across 12 visual backbones, including CNN-, MLP-, and Transformer-based models ranging from 1.89M to 303M parameters, UniDG shows an average accuracy improvement of +5.4% on DomainBed.
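The idea above — an unsupervised loss on test data plus a penalty that keeps parameters close to the source model — can be sketched as follows. This is a conceptual illustration only, not the exact UniDG objective: the function names are hypothetical, and entropy minimization is assumed here as the unsupervised loss.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def adaptation_loss(logits, theta, theta_src, lam=0.1):
    """Unsupervised test-time loss with a source-anchoring penalty (sketch)."""
    # Entropy minimization on unlabeled test predictions
    p = softmax(logits)
    entropy = -(p * np.log(p + 1e-12)).sum(axis=1).mean()
    # Penalty on how far parameters drift from the pretrained source model,
    # intended to limit catastrophic forgetting
    penalty = ((theta - theta_src) ** 2).sum()
    return entropy + lam * penalty

# With theta == theta_src the penalty vanishes and only the entropy term remains
logits = np.array([[2.0, 0.1, 0.1], [0.2, 3.0, 0.1]])
theta_src = np.zeros(4)
base = adaptation_loss(logits, theta_src, theta_src)
```

The weight `lam` trades off plasticity (adapting to the test distribution) against retention of source knowledge.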
- 🌟 Extensibility: we integrate UniDG with DomainBed, so new networks and algorithms can be built easily on top of our framework. UniDG brings an average improvement of +5.0% to existing methods, including ERM, CORAL, and MIRO.
- 🌟 Reproducibility: all implemented models are trained on each task at least three times, and mean ± std is reported in the UniDG paper. Pretrained models and logs are available.
- 🌟 Ease of Use: we provide tools that collect experimental logs as JSON files and convert the results directly into LaTeX tables:
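The JSON-to-LaTeX conversion can be illustrated with a minimal sketch. The record schema below is hypothetical and only stands in for whatever the repo's log format actually contains:

```python
import json

def results_to_latex(json_str):
    """Convert a JSON results record (hypothetical schema) into a LaTeX table row."""
    record = json.loads(json_str)
    cells = [record["algorithm"]] + [
        f'{record["acc"][d]:.1f}' for d in ("VLCS", "PACS", "OfficeHome", "TerraInc")
    ]
    return " & ".join(cells) + r" \\"

row = results_to_latex(
    '{"algorithm": "ERM", '
    '"acc": {"VLCS": 67.6, "PACS": 73.1, "OfficeHome": 68.2, "TerraInc": 47.8}}'
)
print(row)  # ERM & 67.6 & 73.1 & 68.2 & 47.8 \\
```

The actual `collect_all_results` script (used below with `--latex`) handles many runs and seeds at once; this just shows the shape of the transformation.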
- 🌟 Visualization Tools: we provide scripts to easily visualize results with t-SNE plots and performance curves:
- Convergence Curves:
- t-SNE Visualization Results:
We provide pretrained checkpoints with the base ERM algorithm so that our experimental results can be reproduced conveniently. Note that `IID_best.pkl` is the pretrained source model.
- CORAL Source Models

| Backbone | Dataset | Algorithm | Base Model | Adaptation | Google Drive |
|---|---|---|---|---|---|
| ResNet-50 | VLCS & PACS & OfficeHome & TerraInc & DomainNet | CORAL | 64.1 ± 0.1 | 69.3 ± 0.2 | ckpt |
| Swin Transformer | VLCS & PACS & OfficeHome & TerraInc | CORAL | 77.2 ± 0.1 | 82.5 ± 0.2 | ckpt |
| ConvNeXt | VLCS & PACS & OfficeHome & TerraInc & DomainNet | CORAL | 75.1 ± 0.1 | 79.6 ± 0.3 | ckpt |
- ERM Source Models

| Backbone | Dataset | Algorithm | Base Model | Adaptation | Google Drive |
|---|---|---|---|---|---|
| ResNet-18 | VLCS & PACS & OfficeHome & TerraInc | ERM | 63.0 ± 0.0 | 67.2 ± 0.2 | ckpt |
| ResNet-50 | VLCS & PACS & OfficeHome & TerraInc | ERM | 67.6 ± 0.0 | 73.1 ± 0.2 | ckpt |
| ResNet-101 | VLCS & PACS & OfficeHome & TerraInc | ERM | 68.1 ± 0.1 | 72.3 ± 0.3 | ckpt |
| MobileNet V3 | VLCS & PACS & OfficeHome & TerraInc | ERM | 58.9 ± 0.0 | 65.3 ± 0.2 | ckpt |
| EfficientNet V2 | VLCS & PACS & OfficeHome & TerraInc | ERM | 67.2 ± 0.0 | 72.1 ± 0.3 | ckpt |
| ConvNeXt-B | VLCS & PACS & OfficeHome & TerraInc | ERM | 79.7 ± 0.0 | 83.7 ± 0.1 | ckpt |
| ViT-B16 | VLCS & PACS & OfficeHome & TerraInc | ERM | 69.5 ± 0.0 | 75.4 ± 0.2 | ckpt |
| ViT-L16 | VLCS & PACS & OfficeHome & TerraInc | ERM | 74.1 ± 0.0 | 79.9 ± 0.3 | ckpt |
| DeiT | VLCS & PACS & OfficeHome & TerraInc | ERM | 73.5 ± 0.0 | 77.8 ± 0.2 | ckpt |
| Swin Transformer | VLCS & PACS & OfficeHome & TerraInc | ERM | 77.2 ± 0.0 | 81.5 ± 0.3 | ckpt |
| Mixer-B16 | VLCS & PACS & OfficeHome & TerraInc | ERM | 57.2 ± 0.1 | 65.6 ± 0.3 | ckpt |
| Mixer-L16 | VLCS & PACS & OfficeHome & TerraInc | ERM | 67.4 ± 0.0 | 73.0 ± 0.2 | ckpt |
Environment Setup

```shell
git clone https://github.com/invictus717/UniDG.git && cd UniDG
conda env create -f UniDG.yaml && conda activate UniDG
```
Datasets Preparation

```shell
python -m domainbed.scripts.download \
    --data_dir=./data
```
Train a model:

```shell
python -m domainbed.scripts.train \
    --data_dir=./data \
    --algorithm ERM \
    --dataset OfficeHome \
    --test_env 2 \
    --hparams "{\"backbone\": \"resnet50\"}" \
    --output_dir my/pretrain/ERM/resnet50
```
Note that you can download our pretrained checkpoints in the Model Zoo.
Then you can perform self-supervised adaptation:

```shell
python -m domainbed.scripts.unsupervised_adaptation \
    --input_dir my/pretrain/ERM/resnet50 \
    --adapt_algorithm=UniDG
```
Then you can collect and summarize all results:

```shell
python -m domainbed.scripts.collect_all_results \
    --input_dir=my/pretrain/ERM \
    --adapt_dir=results/ERM/resnet50 \
    --output_dir=log/UniDG/ \
    --adapt_algorithm=UniDG \
    --latex
```
For t-SNE visualization:

```shell
python -m domainbed.scripts.visualize_tsne \
    --input_dir=my/pretrain/ERM \
    --adapt_dir=UniDG/results/ERM/resnet50 \
    --output_dir=log/UniDG/ \
    --adapt_algorithm=UniDG \
    --latex
```
For performance-curve visualization:

```shell
python -m domainbed.scripts.visualize_curves \
    --input_dir=my/pretrain/ERM \
    --adapt_dir=UniDG/results/ERM/resnet50 \
    --output_dir=log/UniDG/ \
    --adapt_algorithm=UniDG \
    --latex
```
If this work is helpful for your research, please consider citing the following BibTeX entry.
```bibtex
@article{zhang2023unified,
  title={Towards Unified and Effective Domain Generalization},
  author={Yiyuan Zhang and Kaixiong Gong and Xiaohan Ding and Kaipeng Zhang and Fangrui Lv and Kurt Keutzer and Xiangyu Yue},
  year={2023},
  eprint={2310.10008},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```
This repository is based on DomainBed, T3A, and timm. Many thanks for their great work!
This repository is released under the Apache 2.0 license as found in the LICENSE file.