This is the code implementation of the CVPR 2020 paper "What Deep CNNs Benefit from Global Covariance Pooling: An Optimization Perspective" (poster), created by Qilong Wang and Li Zhang.
Recent works have demonstrated that global covariance pooling (GCP) can improve the performance of deep convolutional neural networks (CNNs) on visual classification tasks. Despite considerable advances, the reasons for the effectiveness of GCP on deep CNNs have not been well studied. In this paper, we attempt to understand what deep CNNs gain from GCP from an optimization perspective. Specifically, we explore the effect of GCP on deep CNNs in terms of the Lipschitzness of the optimization loss and the predictiveness of gradients, and show that GCP makes the optimization landscape smoother and the gradients more predictive. Furthermore, we discuss the connection between GCP and second-order optimization for deep CNNs. More importantly, the above findings account for several merits of covariance pooling for training deep CNNs that have not been recognized previously or fully explored, including significant acceleration of network convergence (i.e., networks trained with GCP support rapid decay of learning rates, achieving favorable performance while significantly reducing the number of training epochs), stronger robustness to distorted examples generated by image corruptions and perturbations, and good generalization to different vision tasks, e.g., object detection and instance segmentation. We conduct extensive experiments using various deep CNN models on diversified tasks, and the results provide strong support for our findings.
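For readers unfamiliar with the operation the paper studies, here is a minimal PyTorch sketch of GCP as a drop-in replacement for global average pooling (GAP). It computes only the plain sample covariance; the released models additionally apply matrix square-root (iSQRT-COV) normalization, which is omitted here for brevity:

```python
import torch

def global_covariance_pooling(x: torch.Tensor) -> torch.Tensor:
    """Minimal GCP sketch: sample covariance of local features.

    x: (N, C, H, W) feature maps -> (N, C, C) covariance matrices.
    The released models further normalize this with iSQRT-COV.
    """
    n, c, h, w = x.shape
    m = h * w
    feats = x.reshape(n, c, m)                       # C-dim features at M spatial positions
    feats = feats - feats.mean(dim=2, keepdim=True)  # center each channel
    return feats @ feats.transpose(1, 2) / (m - 1)   # second-order statistics

# By contrast, GAP returns only the first-order statistic x.mean(dim=(2, 3)).
```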
```
@inproceedings{wang2020deep,
  title={What Deep CNNs Benefit from Global Covariance Pooling: An Optimization Perspective},
  author={Wang, Qilong and Zhang, Li and Wu, Banggu and Ren, Dongwei and Li, Peihua and Zuo, Wangmeng and Hu, Qinghua},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  year={2020}
}
```
- OS: Ubuntu 16.04
- CUDA: 9.0/10.0
- Toolkit: PyTorch 1.3/1.4
- GPU: RTX 2080Ti / Titan XP
a. Create a conda virtual environment and activate it:

```
conda create -n gcp-optimization python=3.7
conda activate gcp-optimization
```

b. Install PyTorch and torchvision following the official instructions, e.g.,

```
conda install pytorch torchvision -c pytorch
```
c. Clone this repository:

```
git clone https://github.com/ZhangLi-CS/GCP_Optimization.git
cd GCP_Optimization
```
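Optionally, a quick sanity check that the environment is usable (assumes a CUDA-capable GPU is visible):

```python
import torch, torchvision

# Should print e.g. "1.4.0 0.5.0 True" on a correctly configured machine.
print(torch.__version__, torchvision.__version__, torch.cuda.is_available())
```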
- Test the loss landscape of MobileNetV2 models: in the folder `./landscape/MobileNetV2`, run `sh ./scripts/train.sh`.
- Test the loss landscape of ResNet models: in the folder `./landscape/ResNet`, run `sh ./scripts/train.sh`.

*Note that you need to modify the dataset path or model name in `train.sh` to fit your configuration; descriptions of all parameters can be found in `./landscape/readme.txt`.
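As background for what these scripts measure, the sketch below illustrates the kind of landscape probe used in such analyses (cf. Santurkar et al., 2018): step along the current gradient with several step sizes and record the resulting loss range, where a smaller range indicates a smoother (more Lipschitz) loss. This is an illustrative outline only, not the repo's exact procedure:

```python
import torch

def loss_range_along_gradient(model, loss_fn, x, y, etas=(0.25, 0.5, 1.0, 2.0)):
    """Probe landscape smoothness: loss variation along the gradient direction."""
    loss = loss_fn(model(x), y)
    params = list(model.parameters())
    grads = torch.autograd.grad(loss, params)
    losses = []
    with torch.no_grad():
        for eta in etas:
            for p, g in zip(params, grads):
                p.sub_(eta * g)                         # move to theta - eta * grad
            losses.append(loss_fn(model(x), y).item())  # loss at the probed point
            for p, g in zip(params, grads):
                p.add_(eta * g)                         # restore theta
    return min(losses), max(losses)                     # small range => smooth landscape
```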
- LRnorm: the learning rate schedule used in each original paper.
- LRfast: the learning rate schedule with the fastest convergence.
- LRadju: the learning rate schedule giving the best trade-off between convergence speed and classification accuracy (an illustrative sketch of such schedules is shown below).
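The milestones below are placeholders, not the repo's actual settings (those live in the training scripts); the sketch just illustrates what a "normal" versus a "fast" step schedule looks like in PyTorch. In practice you create only one scheduler per optimizer:

```python
import torch
from torch.optim.lr_scheduler import MultiStepLR

model = torch.nn.Linear(10, 10)  # stand-in for a real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# LRnorm-like: the classic ImageNet schedule, decaying 10x every ~30 epochs.
scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.1)

# LRfast-like: decay much earlier so training ends in far fewer epochs; the
# paper's point is that GCP-equipped networks tolerate such rapid decay.
# scheduler = MultiStepLR(optimizer, milestones=[10, 20, 25], gamma=0.1)
```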
- Train models (except ShuffleNetV2) on ImageNet: in the folder `src`, run `sh ./scripts/train/train.sh`.
- Train ShuffleNetV2 models on ImageNet: in the folder `src`, run `sh ./scripts/train/train_shufflenet.sh`.
- Test models on ImageNet: in the folder `src`, run `sh ./scripts/val/val.sh`.

*Note that you need to modify the dataset path or model name in `train.sh` or `val.sh` to fit your configuration; descriptions of all parameters can be found in `./src/scripts/readme.txt`.
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
MobileNetV2_GAP_LRnorm | 71.62 | 90.18 | MobileNetV2_GAP_LRnorm | va25 | MobileNetV2_GAP_LRnorm |
MobileNetV2_GAP_LRfast | 69.29 | 89.01 | MobileNetV2_GAP_LRfast | f5n4 | MobileNetV2_GAP_LRfast |
MobileNetV2_GAP_LRadju | 71.27 | 90.08 | MobileNetV2_GAP_LRadju | pai8 | MobileNetV2_GAP_LRadju |
MobileNetV2_GCP_LRnorm | 74.39 | 91.86 | MobileNetV2_GCP_LRnorm | 3r9q | MobileNetV2_GCP_LRnorm |
MobileNetV2_GCP_LRfast | 72.45 | 90.51 | MobileNetV2_GCP_LRfast | dq5w | MobileNetV2_GCP_LRfast |
MobileNetV2_GCP_LRadju | 73.97 | 91.53 | MobileNetV2_GCP_LRadju | i7y3 | MobileNetV2_GCP_LRadju |
MobileNetV2_GCP_LRnorm_128 | 73.28 | 91.26 | MobileNetV2_GCP_LRnorm_128 | cxhu | MobileNetV2_GCP_LRnorm_128 |
MobileNetV2_GCP_LRadju_128 | 72.58 | 90.87 | MobileNetV2_GCP_LRadju_128 | 3qdx | MobileNetV2_GCP_LRadju_128 |
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
ShuffleNetV2_GAP_LRnorm | 67.96 | 87.84 | ShuffleNetV2_GAP_LRnorm | rbew | ShuffleNetV2_GAP_LRnorm |
ShuffleNetV2_GAP_LRfast | 66.13 | 86.54 | ShuffleNetV2_GAP_LRfast | 7y58 | ShuffleNetV2_GAP_LRfast |
ShuffleNetV2_GAP_LRadju | 67.15 | 87.35 | ShuffleNetV2_GAP_LRadju | c7i8 | ShuffleNetV2_GAP_LRadju |
ShuffleNetV2_GCP_LRnorm | 71.83 | 90.04 | ShuffleNetV2_GCP_LRnorm | tr5u | ShuffleNetV2_GCP_LRnorm |
ShuffleNetV2_GCP_LRfast | 70.29 | 89.00 | ShuffleNetV2_GCP_LRfast | 8qxx | ShuffleNetV2_GCP_LRfast |
ShuffleNetV2_GCP_LRadju | 71.17 | 89.74 | ShuffleNetV2_GCP_LRadju | nud1 | ShuffleNetV2_GCP_LRadju |
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
ResNet18_GAP_LRnorm | 70.47 | 89.59 | ResNet18_GAP_LRnorm | z7ab | ResNet18_GAP_LRnorm |
ResNet18_GAP_LRfast | 66.02 | 86.69 | ResNet18_GAP_LRfast | 78uw | ResNet18_GAP_LRfast |
ResNet18_GAP_LRadju | 69.62 | 89.00 | ResNet18_GAP_LRadju | 29cn | ResNet18_GAP_LRadju |
ResNet18_GCP_LRnorm | 75.48 | 92.23 | ResNet18_GCP_LRnorm | eje8 | ResNet18_GCP_LRnorm |
ResNet18_GCP_LRfast | 72.02 | 89.97 | ResNet18_GCP_LRfast | k6f6 | ResNet18_GCP_LRfast |
ResNet18_GCP_LRadju | 74.86 | 91.81 | ResNet18_GCP_LRadju | tci6 | ResNet18_GCP_LRadju |
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
ResNet34_GAP_LRnorm | 74.19 | 91.60 | ResNet34_GAP_LRnorm | 1yp6 | ResNet34_GAP_LRnorm |
ResNet34_GAP_LRfast | 69.88 | 89.25 | ResNet34_GAP_LRfast | nmni | ResNet34_GAP_LRfast |
ResNet34_GAP_LRadju | 73.13 | 91.14 | ResNet34_GAP_LRadju | 5eyk | ResNet34_GAP_LRadju |
ResNet34_GCP_LRnorm | 77.11 | 93.33 | ResNet34_GCP_LRnorm | bn2d | ResNet34_GCP_LRnorm |
ResNet34_GCP_LRfast | 73.88 | 91.42 | ResNet34_GCP_LRfast | wmky | ResNet34_GCP_LRfast |
ResNet34_GCP_LRadju | 76.81 | 93.04 | ResNet34_GCP_LRadju | 4w21 | ResNet34_GCP_LRadju |
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
ResNet50_GAP_LRnorm | 76.02 | 92.97 | ResNet50_GAP_LRnorm | 3r9p | ResNet50_GAP_LRnorm |
ResNet50_GAP_LRfast | 71.08 | 90.04 | ResNet50_GAP_LRfast | reub | ResNet50_GAP_LRfast |
ResNet50_GAP_LRadju | 75.32 | 92.47 | ResNet50_GAP_LRadju | 5tdw | ResNet50_GAP_LRadju |
ResNet50_GCP_LRnorm | 78.56 | 94.09 | ResNet50_GCP_LRnorm | e7iy | ResNet50_GCP_LRnorm |
ResNet50_GCP_LRfast | 75.31 | 92.11 | ResNet50_GCP_LRfast | 3j8e | ResNet50_GCP_LRfast |
ResNet50_GCP_LRadju | 78.03 | 93.95 | ResNet50_GCP_LRadju | n4vq | ResNet50_GCP_LRadju |
ResNet50_GCP_LRnorm_128 | 78.02 | 94.02 | ResNet50_GCP_LRnorm_128 | 976a | ResNet50_GCP_LRnorm_128 |
ResNet50_GCP_LRadju_128 | 77.72 | 93.73 | ResNet50_GCP_LRadju_128 | xf7g | ResNet50_GCP_LRadju_128 |
Models | Top-1 acc.(%) | Top-5 acc.(%) | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|
ResNet101_GAP_LRnorm | 77.67 | 93.83 | ResNet101_GAP_LRnorm | 3u9a | ResNet101_GAP_LRnorm |
ResNet101_GAP_LRfast | 73.13 | 91.06 | ResNet101_GAP_LRfast | 11g2 | ResNet101_GAP_LRfast |
ResNet101_GAP_LRadju | 77.53 | 93.53 | ResNet101_GAP_LRadju | nikb | ResNet101_GAP_LRadju |
ResNet101_GCP_LRnorm | 79.47 | 94.71 | ResNet101_GCP_LRnorm | kr4s | ResNet101_GCP_LRnorm |
ResNet101_GCP_LRfast | 76.38 | 92.82 | ResNet101_GCP_LRfast | iy77 | ResNet101_GCP_LRfast |
ResNet101_GCP_LRadju | 79.18 | 94.47 | ResNet101_GCP_LRadju | ytfb | ResNet101_GCP_LRadju |
*If you would like to evaluate the above pre-trained models, please do the following:

- Download the pre-trained models.
- Test on ImageNet: in the folder `src`, run `sh ./scripts/val/val_download.sh`.
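To inspect a downloaded checkpoint before running the scripts, a minimal sketch (the filename is hypothetical, and the exact key layout depends on how the checkpoint was saved):

```python
import torch

ckpt = torch.load("ResNet50_GCP_LRnorm.pth.tar", map_location="cpu")
state = ckpt.get("state_dict", ckpt)  # training checkpoints often nest the weights
print(sorted(state.keys())[:5])       # peek at the first few parameter names
```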
The datasets and evaluation code can be downloaded from https://github.com/hendrycks/robustness.
Method | mCE (ImageNet-C) | Relative mCE (ImageNet-C) | mFP (ImageNet-P) | mT5D (ImageNet-P) |
---|---|---|---|---|
MobileNetV2_GAP_LRnorm | 86.4 | 113.6 | 79.1 | 96.0 |
MobileNetV2_GCP_LRnorm | 81.7 | 110.6 | 64.3 | 87.6 |
ShuffleNetV2_GAP_LRnorm | 92.7 | 126.7 | 94.7 | 108.2 |
ShuffleNetV2_GCP_LRnorm | 85.2 | 112.6 | 75.2 | 95.5 |
ResNet18+ | 84.7 | 103.9 | 72.8 | 87.0 |
ResNet18_GCP_LRnorm | 76.3 | 101.3 | 53.2 | 77.1 |
ResNet34+ | 77.9 | 98.7 | 61.7 | 79.5 |
ResNet34_GCP_LRnorm | 72.4 | 96.9 | 47.7 | 72.4 |
ResNet50+ | 76.7 | 105.0 | 58.0 | 78.3 |
ResNet50_GCP_LRnorm | 70.7 | 97.9 | 47.5 | 74.6 |
ResNet101+ | 70.3 | 93.7 | 52.6 | 73.9 |
ResNet101_GCP_LRnorm | 65.5 | 89.1 | 42.1 | 68.3 |
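For reference, the metrics above follow Hendrycks & Dietterich: the Corruption Error (CE) for one corruption sums the top-1 error over severities 1-5 and normalizes it by AlexNet's errors, mCE averages CE over the 15 ImageNet-C corruption types, and mFP/mT5D are the analogous flip-probability and top-5 distance metrics on ImageNet-P; lower is better for all of them. A minimal sketch of the CE/mCE computation (error rates are inputs, not reproduced data):

```python
def corruption_error(errors, alexnet_errors):
    """CE for one corruption: top-1 error rates at severities 1..5,
    normalized by AlexNet's errors on the same corruption."""
    return sum(errors) / sum(alexnet_errors)

def mean_corruption_error(errors_by_corruption, alexnet_errors_by_corruption):
    """mCE: average CE over all corruption types (15 for ImageNet-C); lower is better."""
    ces = [corruption_error(e, a)
           for e, a in zip(errors_by_corruption, alexnet_errors_by_corruption)]
    return 100.0 * sum(ces) / len(ces)
```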
We use mmdetection to train and evaluate our models on object detection and instance segmentation.
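mmdetection configs are plain Python, so swapping in a GCP-pretrained backbone amounts to a config fragment like the hypothetical sketch below (the checkpoint path is a placeholder, and the field names follow mmdetection 1.x-era configs; consult the repo's own configs for the exact version used):

```python
# Hypothetical mmdetection config fragment: initialize a ResNet-50 backbone
# from a GCP-pretrained ImageNet checkpoint (path is a placeholder).
model = dict(
    pretrained='path/to/ResNet50_GCP_LRnorm.pth',
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        style='pytorch'),
)
```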
Backbone Model | Method | AP | AP50 | AP75 | APS | APM | APL | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|---|---|---|---|---|
ResNet50 | GAP | 36.4 | 58.2 | 39.2 | 21.8 | 40.0 | 46.2 | ResNet50_GAP | nohm | ResNet50_GAP |
ResNet50 | GCP_D | 36.6 | 58.4 | 39.5 | 21.3 | 40.8 | 47.0 | ResNet50_GCP_D | ct92 | ResNet50_GCP_D |
ResNet50 | GCP_M | 37.1 | 59.1 | 39.9 | 22.0 | 40.9 | 47.6 | ResNet50_GCP_M | 4coy | ResNet50_GCP_M |
ResNet101 | GAP | 38.7 | 60.6 | 41.9 | 22.7 | 43.2 | 50.4 | ResNet101_GAP | c4op | ResNet101_GAP |
ResNet101 | GCP_D | 39.5 | 60.7 | 43.1 | 22.9 | 44.1 | 51.4 | ResNet101_GCP_D | f1nb | ResNet101_GCP_D |
ResNet101 | GCP_M | 39.6 | 61.2 | 43.1 | 23.3 | 43.9 | 51.3 | ResNet101_GCP_M | 2jek | ResNet101_GCP_M |
Backbone Model | Method | AP | AP50 | AP75 | APS | APM | APL | BaiduDrive(models) | Extract code | GoogleDrive |
---|---|---|---|---|---|---|---|---|---|---|
ResNet50 | GAP | 37.2 | 58.9 | 40.3 | 22.2 | 40.7 | 48.0 | ResNet50_GAP | wg4y | ResNet50_GAP |
ResNet50 | GCP_D | 37.3 | 58.8 | 40.4 | 22.0 | 41.1 | 48.2 | ResNet50_GCP_D | b7fw | ResNet50_GCP_D |
ResNet50 | GCP_M | 37.9 | 59.4 | 41.3 | 22.4 | 41.5 | 49.0 | ResNet50_GCP_M | 3p37 | ResNet50_GCP_M |
ResNet101 | GAP | 39.4 | 60.9 | 43.3 | 23.0 | 43.7 | 51.4 | | | |
ResNet101 | GCP_D | 40.3 | 61.5 | 44.0 | 24.1 | 44.7 | 52.5 | ResNet101_GCP_D | fdhq | ResNet101_GCP_D |
ResNet101 | GCP_M | 40.7 | 62.0 | 44.6 | 23.9 | 45.2 | 52.9 | ResNet101_GCP_M | 0fbl | ResNet101_GCP_M |
Backbone Model | Method | AP | AP50 | AP75 | APS | APM | APL |
---|---|---|---|---|---|---|---|
ResNet50 | GAP | 34.1 | 55.5 | 36.2 | 16.1 | 36.7 | 50.0 |
ResNet50 | GCP_D | 34.2 | 55.3 | 36.4 | 15.8 | 37.1 | 50.1 |
ResNet50 | GCP_M | 34.7 | 56.3 | 36.8 | 16.4 | 37.5 | 50.6 |
ResNet101 | GAP | 35.9 | 57.7 | 38.4 | 16.8 | 39.1 | 53.6 |
ResNet101 | GCP_D | 36.5 | 58.2 | 38.9 | 17.3 | 39.9 | 53.5 |
ResNet101 | GCP_M | 36.7 | 58.7 | 39.1 | 17.6 | 39.9 | 53.7 |
We would like to thank the teams behind iSQRT-COV, ImageNet-C & ImageNet-P, and mmdetection for providing their nice code, on which our code is based.
If you have any questions or suggestions, please feel free to contact us: [email protected]; [email protected].