Skip to content

Commit

Permalink
[Feature] Add PARE (#161)
Browse files Browse the repository at this point in the history
* Training and testing code for PARE: Part Attention Regressor for 3D Human Body Estimation [ICCV 2021].

* Achieving 49.35mm PA-MPJPE, 81.79 MPJPE on 3DPW, compared to the original implementation with 50.9mm PA-MPJPE, 82 MPJPE.

* Provided with detailed pre-train and training config.
  • Loading branch information
WYJSJTU authored Apr 29, 2022
1 parent 0d92745 commit 46d3dae
Show file tree
Hide file tree
Showing 33 changed files with 3,754 additions and 125 deletions.
90 changes: 90 additions & 0 deletions configs/pare/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# PARE

## Introduction

We provide the config files for PARE: [Part Attention Regressor for 3D Human Body Estimation](https://arxiv.org/abs/2104.08527).

```BibTeX
@inproceedings{Kocabas_PARE_2021,
title = {{PARE}: Part Attention Regressor for {3D} Human Body Estimation},
author = {Kocabas, Muhammed and Huang, Chun-Hao P. and Hilliges, Otmar and Black, Michael J.},
booktitle = {Proc. International Conference on Computer Vision (ICCV)},
pages = {11127--11137},
month = oct,
year = {2021},
doi = {},
month_numeric = {10}
}
```

## Notes

- [SMPL](https://smpl.is.tue.mpg.de/) v1.0 is used in our experiments.
- [J_regressor_extra.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_extra.npy?versionId=CAEQHhiBgIDD6c3V6xciIGIwZDEzYWI5NTBlOTRkODU4OTE1M2Y4YTI0NTVlZGM1)
- [J_regressor_h36m.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_h36m.npy?versionId=CAEQHhiBgIDE6c3V6xciIDdjYzE3MzQ4MmU4MzQyNmRiZDA5YTg2YTI5YWFkNjRi)
- [smpl_mean_params.npz](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/smpl_mean_params.npz?versionId=CAEQHhiBgICN6M3V6xciIDU1MzUzNjZjZGNiOTQ3OWJiZTJmNThiZmY4NmMxMTM4)
- Pascal Occluders for the pretraining:
- [pascal_occluders.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/pascal_occluders.npy?versionId=CAEQOhiBgMCH2fqigxgiIDY0YzRiNThkMjU1MzRjZTliMTBhZmFmYWY0MTViMTIx)

As for pretrained model (hrnet_w32_conv_pare_coco.pth). You can download it from [here](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/hrnet_w32_conv_pare_coco.pth?versionId=CAEQOhiBgMCxmv_RgxgiIDkxNWJhOWMxNDEyMzQ1OGQ4YTQ3NjgwNjA0MWUzNDE5) and change the path of pretrained model in the config.
You can also pretrain the model using [hrnet_w32_conv_pare_coco.py]([hrnet_w32_conv_pare_coco.py]). Download the hrnet pretrain from [here](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/hrnet_pretrain.pth?versionId=CAEQOhiBgMC26fSigxgiIGViMTFiZmJkZDljMDRhMWY4Mjc5Y2UzNzBmYzU1MGVk
) for pretrain.

Download the above resources and arrange them in the following file structure:

```text
mmhuman3d
├── mmhuman3d
├── docs
├── tests
├── tools
├── configs
└── data
├── gmm_08.pkl
├── body_models
│ ├── J_regressor_extra.npy
│ ├── J_regressor_h36m.npy
│ ├── smpl_mean_params.npz
│ └── smpl
│ ├── SMPL_FEMALE.pkl
│ ├── SMPL_MALE.pkl
│ └── SMPL_NEUTRAL.pkl
├── pretrained
│ ├── hrnet_pretrain.pth
│ └── hrnet_w32_conv_pare_coco.pth
├── preprocessed_datasets
│ ├── h36m_mosh_train.npz
│ ├── h36m_train.npz
│ ├── mpi_inf_3dhp_train.npz
│ ├── eft_mpii.npz
│ ├── eft_lspet.npz
│ ├── eft_coco_all.npz
│ ├── pw3d_test.npz
├── occluders
│ ├── pascal_occluders.npy
└── datasets
├── coco
├── h36m
├── lspet
├── mpi_inf_3dhp
├── mpii
└── pw3d
```


## Results and Models

We evaluate PARE on 3DPW. Values are MPJPE/PA-MPJPE.

Trained with MoShed Human3.6M Datasets and Cache:

| Config | 3DPW | Download |
|:------:|:-------:|:------:|
| [hrnet_w32_conv_pare_mix_cache.py](hrnet_w32_conv_pare_mix_cache.py) | 81.79 / 49.35 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/with_mosh/hrnet_w32_conv_pare_mosh.pth?versionId=CAEQOhiBgIDooeHSgxgiIDkwYzViMTUyNjM1MjQ3ZDNiNzNjMjJlOGFlNjgxYjlh) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/with_mosh/20220427_113717.log?versionId=CAEQOhiBgMClqr3PgxgiIGRjZWU0NzFhMmVkMDQzN2I5ZmY5Y2MxMzJiZDM3MGQ0) |


Trained without MoShed Human3.6M Datasets:
| Config | 3DPW | Download |
|:------:|:-------:|:------:|
| [hrnet_w32_conv_pare_mix_no_mosh.py](hrnet_w32_conv_pare_mix_no_mosh.py) | 81.81 / 50.78 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/without_mosh/hrnet_w32_conv_pare.pth?versionId=CAEQOhiBgMCi4YbVgxgiIDgzYzFhMWNlNDE2NTQwN2ZiOTQ1ZGJmYTM4OTNmYWY5) | [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/without_mosh/20220427_113844.log?versionId=CAEQOhiBgMCHwcTPgxgiIGI0NjI0M2JiM2ViMzRhMTFiMWQxZDJmMGI5MmQwMjgw) |
207 changes: 207 additions & 0 deletions configs/pare/hrnet_w32_conv_pare_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use_adversarial_train = True

# evaluate
evaluation = dict(interval=10, metric=['pa-mpjpe', 'mpjpe'])
# optimizer

optimizer = dict(
backbone=dict(type='Adam', lr=2.0e-4),
head=dict(type='Adam', lr=2.0e-4),
)
optimizer_config = dict(grad_clip=None)

lr_config = dict(policy='Fixed', by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=200)

log_config = dict(
interval=50, hooks=[
dict(type='TextLoggerHook'),
])

_base_ = ['../_base_/default_runtime.py']
checkpoint_config = dict(interval=10)
width = 32
downsample = False
use_conv = True
hrnet_extra = dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(width, width * 2)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(width, width * 2, width * 4)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(width, width * 2, width * 4, width * 8)),
downsample=downsample,
use_conv=use_conv,
pretrained_layers=[
'conv1',
'bn1',
'conv2',
'bn2',
'layer1',
'transition1',
'stage2',
'transition2',
'stage3',
'transition3',
'stage4',
],
final_conv_kernel=1,
return_list=False,
)

find_unused_parameters = True

model = dict(
type='ImageBodyModelEstimator',
backbone=dict(
type='PoseHighResolutionNet',
extra=hrnet_extra,
num_joints=24,
init_cfg=dict(
type='Pretrained',
checkpoint='data/pretrained_models/hrnet_pretrain.pth')),
head=dict(
type='PareHead',
num_joints=24,
num_input_features=480,
smpl_mean_params='data/body_models/smpl_mean_params.npz',
num_deconv_layers=2,
num_deconv_filters=[128] *
2, # num_deconv_filters = [num_deconv_filters] * num_deconv_layers
num_deconv_kernels=[4] *
2, # num_deconv_kernels = [num_deconv_kernels] * num_deconv_layers
use_heatmaps='part_segm',
use_keypoint_attention=True,
backbone='hrnet_w32-conv',
),
body_model_train=dict(
type='SMPL',
keypoint_src='smpl_54',
keypoint_dst='smpl_49',
model_path='data/body_models/smpl',
keypoint_approximate=True,
extra_joints_regressor='data/body_models/J_regressor_extra.npy'),
body_model_test=dict(
type='SMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
convention='smpl_49',
loss_keypoints3d=dict(type='MSELoss', loss_weight=300),
loss_keypoints2d=dict(type='MSELoss', loss_weight=300),
loss_smpl_pose=dict(type='MSELoss', loss_weight=60),
loss_smpl_betas=dict(type='MSELoss', loss_weight=60 * 0.001),
loss_segm_mask=dict(type='CrossEntropyLoss', loss_weight=60),
loss_camera=dict(type='CameraPriorLoss', loss_weight=1),
)

# dataset settings
dataset_type = 'HumanImageDataset'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data_keys = [
'has_smpl', 'has_keypoints3d', 'has_keypoints2d', 'smpl_body_pose',
'smpl_global_orient', 'smpl_betas', 'smpl_transl', 'keypoints2d',
'keypoints3d', 'sample_idx'
]
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomChannelNoise', noise_factor=0.4),
dict(
type='SyntheticOcclusion',
occluders_file='data/occluders/pascal_occluders.npy'),
dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_49'),
dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25),
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=data_keys),
dict(
type='Collect',
keys=['img', *data_keys],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0),
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=data_keys),
dict(
type='Collect',
keys=['img', *data_keys],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

inference_pipeline = [
dict(type='MeshAffine', img_res=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(
type='Collect',
keys=['img', 'sample_idx'],
meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

data = dict(
samples_per_gpu=64,
workers_per_gpu=0,
train=dict(
type='MixedDataset',
configs=[
dict(
type=dataset_type,
dataset_name='coco',
data_prefix='data',
pipeline=train_pipeline,
convention='smpl_49',
ann_file='eft_coco_all.npz'),
],
partition=[1.0],
),
test=dict(
type=dataset_type,
body_model=dict(
type='GenderedSMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
dataset_name='pw3d',
data_prefix='data',
pipeline=test_pipeline,
ann_file='pw3d_test.npz'),
val=dict(
type=dataset_type,
body_model=dict(
type='GenderedSMPL',
keypoint_src='h36m',
keypoint_dst='h36m',
model_path='data/body_models/smpl',
joints_regressor='data/body_models/J_regressor_h36m.npy'),
dataset_name='pw3d',
data_prefix='data',
pipeline=test_pipeline,
ann_file='pw3d_test.npz'),
)
Loading

0 comments on commit 46d3dae

Please sign in to comment.