Skip to content

ATang0729/FashionAI-KeyPointsDetectionOfApparel

 
 

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

23 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

级联金字塔网络的Pytorch实现,用于时尚AI的关键点检测

English | 简体中文

这段代码在Pytorch中实现了用于多人姿势估计的级联金字塔网络,以检测服装的关键点,包括五种类型:上衣、礼服、外衣、裙子和裤子。这是在正式代码发布之前开始的,所以有一些差异。在实验中,测试了对CPN的各种修改,其中ResNet-152主干网和SENet-154主干网显示了更好的结果。

Example output of CPN with ResNet-152 backbone

开始

该中文版本的作者并非原作者。本文作者在复现2018年的原始版本过程中发现,原始版本并不适用于Python>=3.7(经测试发现对Python3.6兼容),因此对原始版本做出了略微修改和注释。本文档说明写于2023年2月28日,并非完全是对原文档的翻译,还有作者对原文档做出的补充和基于最近commit的修改。

The author of the Chinese version is not the original author. In the process of reproducing the 2018 original, the author found that the original does not work with Python>=3.7 (it was tested and found to be compatible with Python3.6), so minor changes were made and commented to the original. This documentation note, written on February 28, 2023, is not a complete translation of the original document, but contains additions to the original document and changes made by the author based on the recent commit.

环境

  • Python >=3.7

  • Numpy

  • Pandas

  • PyTorch

  • cv2

  • scikit-learn

  • py3nvml and nvidia-ml-py3

  • tqdm

  • random

  • math

数据准备

本版本的数据集来源(应该是2018年参赛的人后来上传的,原始路径不存在了):FashionAI dataset。数据集描述如下:

数据名称 大小 描述
README.md 9.56KB 数据文档
eval.zip 1.17KB
fashionAI_keypoints_train2.tar 2.10GB 训练集2
fashionAI_keypoints_train1.tar 3.00GB 训练集1
fashionAI_keypoints_test.tar 3.00GB 测试集
FashionAI_A_Hierarchical_Dataset_
for_Fashion_Understanding
674.88KB FashionAI数据集论文

⚠️在此版本中,数据集的路径与原始版本并不相同,且仅使用了原始数据集的5%。文件路径如下

fashion/
  |-- checkpoints
  |-- tmp
  |    |-- one
  |    |-- ensemble
  |-- kp_predictions
  |    |-- one
  |    |-- ensemble
  |-- KPDA/
       |-- test_extracted.csv
       |-- train_extracted.csv
       |-- train1/
       |    |-- blouse/
       |    |-- dress/
       |    |-- outwear/
       |-- train2/
       |    |-- skirt/
       |    |-- trousers/
       |-- test/
            |-- blouse/
            |-- dress/
            |-- outwear/
            |-- skirt/
            |-- trousers/
            |-- test.csv

fasion/是本版本用到的数据集的根目录

checkpoints是训练过程中保存的模型参数

tmp是训练过程中保存的临时文件

kp_predictions是训练过程中保存的预测结果

KPDA是原始数据集的存放目录

  • train1->fashionAI_keypoints_train1.tar
  • train2->fashionAI_keypoints_train2.tar
  • test->fashionAI_keypoints_test.tar

config.py中,可以通过修改proj_path来修改数据路径,包括数据读取路径和checkpoints、运行结果的保存路径

模型训练

超参数(batch size, cuda devices, learning rate ,workers,epoch)在config.py中进行修改

从零开始训练模型

python3 src/stage2/trainval.py -c {clothing type}

python src/stage2/trainval.py -c {clothing type}

使用-c或者--clothes来选择服装类型(blouse,dress,outwear,skirt,trouser中的一种)。

你也可以通过下列代码来自动运行

bash src/stage2/autorun.sh

它实际上为五种服装类型运行了stage2/trainval.py五次。

从checkpoints恢复训练

理解Checkpoint - 知乎

python3 src/stage2/trainval.py -c {clothing type} -r {path_to_the_checkpoint}

python src/stage2/trainval.py -c {clothing type} -r {path_to_the_checkpoint}

当恢复训练时,步数、学习率和优化器状态也将从checkpoints恢复。对于SGD优化器,优化器状态包含每个可训练参数的动量(momentum)。例如(代码见trainval.py line 187 to 193):

torch.save({
    'epoch': epoch,
    'save_dir': save_dir,
    'state_dict': state_dict,
    'lr': lr,
    'best_val_loss': best_val_loss},
    os.path.join(save_dir, 'kpt_' + config.clothes + '_%03d.ckpt' % epoch))

训练脚本的背后

  • 数据预处理在stage2/data_generator.py中进行,在训练中调用。

  • 本次挑战赛使用了两个网络,分别是 stage2/cascaded_pyramid_network.pystage2v9/cascaded_pyramid_network_v9.py。最后的分数来自于集合学习。这两个网络共享相同的架构,但骨架不同。

  • 所有其他版本都是失败的实验,可以暂时忽略。

模型验证和测试

基于验证集的模型验证

为了验证模型,请运行下列代码:

python3 src/stage2/predict_one.py -c {clothing type} -g {gpu index} -m {path/to/the/model} -v {True/False}

为了验证两个模型的综合性能,请运行:

python3 src/stage2/predict_ensemble.py -c {clothing type} -g {gpu index} -m1 {path/to/the/model1} -m2 {path/to/the/model2} -v {True/False}

在程序结束时,会打印出normalized error

在测试集上生成用于成绩提交的结果

测试单个模型,请运行:

python3 src/kpdetector/predict.py -c {clothing type} -g {gpu index} -m {path/to/the/model} -v {True/False}

测试两个模型的综合性能:

python3 src/kpdetector/predict_ensemble.py -c {clothing type} -g {gpu index} -m1 {path/to/the/model1} -m2 {path/to/the/model2} -v {True/False}

运行python3 src/kpdetector/concatenate_results.py以此将所有用于提交的.csv文件进行合并

实验(normalized error的降低)

  • Replace ResNet50 by ResNet152 as backbone network (-0.5%)
  • Increase input resolution from 256x256 to 512x512 (-2.5%)
  • Gaussian blur on predicted heatmap (-0.5%)
  • Reduce rotaton angle from 40 degree to 30 for data augmentation (-0.6%)
  • Use (x+2, y+2) where (x, y) is max value coordinate (-0.4%)
  • Use 1/4 offset from coordinate of the max value to the one of second max value (-0.2%)
  • Flip left to right for data augmentation (-0.2%)

Benchmark

This solution achieved LB 3.82% in Tianchi FashionAI Global Challenge, 17th place out 2322 teams. Check the leaderboard here.

About

FashionAI Key Points Detection using CPN model in Pytorch

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages

  • Python 99.5%
  • Shell 0.5%