[Add] CLIFF (#302)
* add cliff head

* add function to convert from crop to full camera

* add cliff annotation datasets converter

* add transforms to get bbox information

* store crop trans

* cliff mesh estimator

* modification to take in different resolutions

* add configs

* add missing comma

* format correction

* isort formatting

* correct error in cliff_head

* revert unnecessary changes in cliff_head

* add configs (single dataset) and small modification

* configs format modification

* add test for cliff head

* format correction

* update test file

* format correction

* update test file

* format correction

* update test file

* format correction

* docformatter correction

* update test file

* format

* add README

* add README

* add test for cliff data converter

* add test for cliff mesh estimator

* update tests

* merge cliff mesh estimator to mesh estimator

* revert unnecessary tests

* format

* Revert to CliffMeshEstimator

* Fix wrong class name in test

* Fix linter

* Fix bugs for test architecture

* Fix test_data_converters.py

* Update download links

* Update pytorch3d install in workflow

* Format

* Add additional tests

* Update to ubuntu-20.04

* Update to ubuntu-20.04

* Fix pickle

* Fix setup.cfg

* Fix setup.cfg

* Change pickle5 to pickle

* Fix pandas version

---------

Co-authored-by: caizhongang <[email protected]>
Co-authored-by: caizhongang <[email protected]>
3 people authored Apr 5, 2023
1 parent 2e54142 commit 7996ee5
Showing 22 changed files with 2,402 additions and 36 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -26,7 +26,7 @@ concurrency:

 jobs:
   build_cuda101:
-    runs-on: ubuntu-18.04
+    runs-on: ubuntu-20.04
     strategy:
       matrix:
         python-version: [3.8]
@@ -69,7 +69,7 @@ jobs:
     - name: Install pytorch3d
       run: |
         conda install -c fvcore -c iopath -c conda-forge fvcore iopath -y
-        conda install pytorch3d -c pytorch3d
+        pip install "git+https://github.com/facebookresearch/pytorch3d.git"
     - name: Install MMCV
       run: |
         pip install "mmcv-full>=1.3.17,<=1.5.3" -f https://download.openmmlab.com/mmcv/dist/cpu/torch${{matrix.torch}}/index.html
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -8,7 +8,7 @@ concurrency:

 jobs:
   lint:
-    runs-on: ubuntu-18.04
+    runs-on: ubuntu-20.04
     steps:
     - uses: actions/checkout@v2
     - name: Set up Python 3.8
81 changes: 81 additions & 0 deletions configs/cliff/README.md
@@ -0,0 +1,81 @@
# CLIFF

## Introduction

We provide the config files for CLIFF: [CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation](https://arxiv.org/pdf/2208.00571.pdf).

```BibTeX
@Inproceedings{li2022cliff,
author = {Li, Zhihao and
Liu, Jianzhuang and
Zhang, Zhensong and
Xu, Songcen and
Yan, Youliang},
title = {CLIFF: Carrying Location Information in Full Frames into Human Pose and Shape Estimation},
booktitle = {ECCV},
year = {2022}
}
```

## Notes

- [SMPL](https://smpl.is.tue.mpg.de/) v1.0 is used in our experiments.
- [J_regressor_extra.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_extra.npy?versionId=CAEQHhiBgIDD6c3V6xciIGIwZDEzYWI5NTBlOTRkODU4OTE1M2Y4YTI0NTVlZGM1)
- [J_regressor_h36m.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/J_regressor_h36m.npy?versionId=CAEQHhiBgIDE6c3V6xciIDdjYzE3MzQ4MmU4MzQyNmRiZDA5YTg2YTI5YWFkNjRi)
- [pascal_occluders.npy](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/pare/pascal_occluders.npy?versionId=CAEQOhiBgMCH2fqigxgiIDY0YzRiNThkMjU1MzRjZTliMTBhZmFmYWY0MTViMTIx)
- [resnet50_a1h2_176-001a1197.pth](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth)
- [resnet50_a1h2_176-001a1197.pth(alternative download link)](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_a1h2_176-001a1197.pth)

Download the above resources and arrange them in the following file structure:

```text
mmhuman3d
├── mmhuman3d
├── docs
├── tests
├── tools
├── configs
└── data
    ├── checkpoints
    │   ├── resnet50_a1h2_176-001a1197.pth
    ├── body_models
    │   ├── J_regressor_extra.npy
    │   ├── J_regressor_h36m.npy
    │   ├── smpl_mean_params.npz
    │   └── smpl
    │       ├── SMPL_FEMALE.pkl
    │       ├── SMPL_MALE.pkl
    │       └── SMPL_NEUTRAL.pkl
    ├── preprocessed_datasets
    │   ├── cliff_coco_train.npz
    │   ├── cliff_mpii_train.npz
    │   ├── h36m_mosh_train.npz
    │   ├── muco3dhp_train.npz
    │   ├── mpi_inf_3dhp_train.npz
    │   └── pw3d_test.npz
    ├── occluders
    │   ├── pascal_occluders.npy
    └── datasets
        ├── coco
        ├── h36m
        ├── muco
        ├── mpi_inf_3dhp
        ├── mpii
        └── pw3d
```

## Training

Stage 1: First train with [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py).

Stage 2: After around 150 epochs, switch to [resume.py](resume.py) using the `--resume-from` argument.

## Results and Models

We evaluate CLIFF on 3DPW. Values are MPJPE/PA-MPJPE.

| Config | 3DPW | Download |
|:----------------------------------------------------------:|:-------------:|:--------:|
| Stage 1: [resnet50_pw3d_cache.py](resnet50_pw3d_cache.py) | 48.65 / 76.49 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff-8328e2e2_20230327.pth) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20220909_142945.log) |
| Stage 2: [resume.py](resume.py) | 47.38 / 75.08 | [model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/resnet50_cliff_new-1e639f1d_20230327.pth) &#124; [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmhuman3d/models/cliff/20230222_092227.log) |
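
The core idea this commit brings in is CLIFF's full-frame camera handling: the head predicts a weak-perspective camera on the crop, and a helper converts it to a translation in the original camera (see "add function to convert from crop to full camera" in the commit log). Below is a minimal sketch of that conversion, adapted from the reference CLIFF implementation; the function name, tensor shapes, and the `scale * 200` bbox-size convention are illustrative rather than the exact mmhuman3d API.

```python
import torch


def cam_crop2full(crop_cam, center, scale, full_img_shape, focal_length):
    """Convert a weak-perspective crop camera (s, tx, ty) into a full-frame
    translation (tx, ty, tz), given the bbox center, the bbox scale
    (side / 200), the full image shape (H, W) and the focal length.
    All inputs are batched tensors."""
    img_h, img_w = full_img_shape[:, 0], full_img_shape[:, 1]
    cx, cy, b = center[:, 0], center[:, 1], scale * 200.0
    w_2, h_2 = img_w / 2.0, img_h / 2.0
    bs = b * crop_cam[:, 0] + 1e-9               # s * b, guards against s == 0
    tz = 2.0 * focal_length / bs                 # depth from apparent bbox size
    tx = (2.0 * (cx - w_2) / bs) + crop_cam[:, 1]  # shift by bbox offset in x
    ty = (2.0 * (cy - h_2) / bs) + crop_cam[:, 2]  # shift by bbox offset in y
    return torch.stack([tx, ty, tz], dim=-1)
```

Expressing the camera in the full frame is what lets the 2D reprojection loss be computed against the original image rather than the crop, which is the point of CLIFF.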
189 changes: 189 additions & 0 deletions configs/cliff/coco.py
@@ -0,0 +1,189 @@
_base_ = ['../_base_/default_runtime.py']
use_adversarial_train = True

# evaluate
evaluation = dict(metric=['pa-mpjpe', 'mpjpe'])
# optimizer
optimizer = dict(
    backbone=dict(type='Adam', lr=1e-4),
    head=dict(type='Adam', lr=1e-4),
    # disc=dict(type='Adam', lr=1e-4)
)
optimizer_config = dict(grad_clip=2.0)
# learning policy
lr_config = dict(policy='Fixed', by_epoch=False)
runner = dict(type='EpochBasedRunner', max_epochs=800)

log_config = dict(
    interval=50,
    hooks=[
        dict(type='TextLoggerHook'),
        # dict(type='TensorboardLoggerHook')
    ])

img_resolution = (192, 256)

# model settings
model = dict(
    type='CliffImageBodyModelEstimator',
    backbone=dict(
        type='ResNet',
        depth=50,
        out_indices=[3],
        norm_eval=False,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='data/checkpoints/resnet50_a1h2_176-001a1197.pth')),
    head=dict(
        type='CliffHead',
        feat_dim=2048,
        smpl_mean_params='data/body_models/smpl_mean_params.npz'),
    body_model_train=dict(
        type='SMPL',
        keypoint_src='smpl_54',
        keypoint_dst='smpl_54',
        model_path='data/body_models/smpl',
        keypoint_approximate=True,
        extra_joints_regressor='data/body_models/J_regressor_extra.npy'),
    body_model_test=dict(
        type='SMPL',
        keypoint_src='h36m',
        keypoint_dst='h36m',
        model_path='data/body_models/smpl',
        joints_regressor='data/body_models/J_regressor_h36m.npy'),
    convention='smpl_54',
    loss_keypoints3d=dict(type='SmoothL1Loss', loss_weight=100),
    loss_keypoints2d=dict(type='SmoothL1Loss', loss_weight=10),
    loss_vertex=dict(type='L1Loss', loss_weight=2),
    loss_smpl_pose=dict(type='MSELoss', loss_weight=3),
    loss_smpl_betas=dict(type='MSELoss', loss_weight=0.02),
    loss_adv=dict(
        type='GANLoss',
        gan_type='lsgan',
        real_label_val=1.0,
        fake_label_val=0.0,
        loss_weight=1),
    # disc=dict(type='SMPLDiscriminator')
)
# dataset settings
dataset_type = 'HumanImageDataset'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data_keys = [
    'has_smpl',
    'smpl_body_pose',
    'smpl_global_orient',
    'smpl_betas',
    'smpl_transl',
    'keypoints2d',
    'keypoints3d',
    'sample_idx',
    'img_h',  # extras for cliff
    'img_w',
    'focal_length',
    'center',
    'scale',
    'bbox_info',
    'crop_trans',
    'inv_trans'
]
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomChannelNoise', noise_factor=0.4),
    dict(type='RandomHorizontalFlip', flip_prob=0.5, convention='smpl_54'),
    dict(type='GetRandomScaleRotation', rot_factor=30, scale_factor=0.25),
    dict(type='GetBboxInfo'),
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=data_keys),
    dict(
        type='Collect',
        keys=['img', *data_keys],
        meta_keys=['image_path', 'center', 'scale', 'rotation'])
]
adv_data_keys = [
    'smpl_body_pose', 'smpl_global_orient', 'smpl_betas', 'smpl_transl'
]
train_adv_pipeline = [dict(type='Collect', keys=adv_data_keys, meta_keys=[])]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='GetRandomScaleRotation', rot_factor=0, scale_factor=0),
    dict(type='GetBboxInfo'),
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=data_keys),
    dict(
        type='Collect',
        keys=['img', *data_keys],
        meta_keys=[
            'image_path', 'center', 'scale', 'rotation', 'img_h', 'img_w',
            'bbox_info'
        ])
]

inference_pipeline = [
    dict(type='MeshAffine', img_res=img_resolution),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(
        type='Collect',
        keys=['img', 'sample_idx'],
        meta_keys=['image_path', 'center', 'scale', 'rotation'])
]

cache_files = {
    'cliff_coco': 'data/cache/cliff_coco_train_smpl_54.npz',
}
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type='AdversarialDataset',
        train_dataset=dict(
            type='MixedDataset',
            configs=[
                dict(
                    type=dataset_type,
                    dataset_name='coco',
                    data_prefix='data',
                    pipeline=train_pipeline,
                    convention='smpl_54',
                    cache_data_path=cache_files['cliff_coco'],
                    ann_file='cliff_coco_train.npz'),
            ],
            partition=[1.0],
        ),
        adv_dataset=dict(
            type='MeshDataset',
            dataset_name='cmu_mosh',
            data_prefix='data',
            pipeline=train_adv_pipeline,
            ann_file='cmu_mosh.npz')),
    val=dict(
        type=dataset_type,
        body_model=dict(
            type='GenderedSMPL',
            keypoint_src='h36m',
            keypoint_dst='h36m',
            model_path='data/body_models/smpl',
            joints_regressor='data/body_models/J_regressor_h36m.npy'),
        dataset_name='pw3d',
        data_prefix='data',
        pipeline=test_pipeline,
        ann_file='pw3d_test.npz'),
    test=dict(
        type=dataset_type,
        body_model=dict(
            type='GenderedSMPL',
            keypoint_src='h36m',
            keypoint_dst='h36m',
            model_path='data/body_models/smpl',
            joints_regressor='data/body_models/J_regressor_h36m.npy'),
        dataset_name='pw3d',
        data_prefix='data',
        pipeline=test_pipeline,
        ann_file='pw3d_test.npz'),
)
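
For context on the `GetBboxInfo` transform and the `bbox_info` key collected above: CLIFF conditions its head on where the crop sits in the full frame. The sketch below shows one way such a 3-vector can be formed, following the normalisation used in the reference CLIFF code; the exact mmhuman3d transform may differ in detail, and `estimate_focal_length` plus the constants 2.8, 0.24 and 0.06 are taken from that reference rather than from this repository.

```python
import numpy as np


def estimate_focal_length(img_h, img_w):
    """CLIFF's heuristic when intrinsics are unknown: f ~ image diagonal."""
    return (img_w ** 2 + img_h ** 2) ** 0.5


def get_bbox_info(center, scale, img_h, img_w):
    """Encode bbox location and size relative to the full frame.

    center: (cx, cy) of the person bbox in the original image.
    scale:  bbox size / 200, so b = 200 * scale is the square bbox side.
    Returns the 3-vector that is concatenated to the image features.
    """
    cx, cy = center
    b = 200.0 * scale
    focal_length = estimate_focal_length(img_h, img_w)
    bbox_info = np.array([cx - img_w / 2.0, cy - img_h / 2.0, b],
                         dtype=np.float32)
    # Normalisation constants follow the reference CLIFF code.
    bbox_info[:2] = bbox_info[:2] / focal_length * 2.8
    bbox_info[2] = (bbox_info[2] - 0.24 * focal_length) / (0.06 * focal_length)
    return bbox_info
```

The resulting vector is what `CliffHead` consumes alongside the 2048-d backbone feature, which is why `img_h`, `img_w`, `focal_length`, `center`, `scale` and `bbox_info` all appear in `data_keys`.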