Skip to content

civat/DL-Theory-Practice

Repository files navigation

Deep Learning Theory and Practice

Hello


欢迎来到Deep Learning Theory and Practice!

仓库会尝试用尽可能统一的框架去实现一些常见CV任务中的经典网络。我们的目标是:仅用配置文件就能实现非常丰富的网络结构定义,而不依靠大量的代码。

目前仓库的代码框架进行了大量更新,原来老版本请使用Releases中的v0.1版本。老版本的代码不会再更新。

新版本中,新增了大量的wiki来介绍一些框架的基本信息和使用方法,相信能帮助大家更好地理解和使用本仓库代码。

Keep in mind:

  • 仓库所有代码是业务时间在写,精力有限,测试用例覆盖不全,可能会有bug,见谅的同时也欢迎指出。
  • 欢迎大家向仓库贡献代码或config文件。
  • 如果你有新的需求,欢迎提issue或知乎私信,我会尽量满足大家的需求。
  • 仓库有wiki,里面写了关于仓库的详细介绍,欢迎大家阅读。
  • 如果有帮助,希望给个star,谢谢!

如何使用


以分类任务为例,要训练某个网络,有两种方式:

  • 第一种方式是通过命令行参数指定配置文件路径(将xxx.yaml替换为配置文件路径):
python trainer_classification.py --config_file xxx.yaml

例如,当需要训练MobileNet v1时,可以使用:

python trainer_classification.py --config_file classification/configs/MobileNet_v1/MobileNet_ImageNet_224_EXP.yaml
  • 第二种方式是将trainer_classification.py中的config_file变量的默认值值指定为配置文件路径,然后直接运行trainer_classification.py即可。

对于分类任务,有两种方法来指定数据集:

  1. 通过关键字name来指定PyTorch内置的数据集。并通过root_path关键字来指定数据集存放的路径。当数据集不存在时,会自动下载到该目录下。目前只支持CIFAR10。

  2. 通过关键字trn_path和tst_path来分别指定训练集和测试集所在的根目录。在这种设置下,数据需要按照如下结构准备:

trn_path下包含多个文件夹,每一个文件夹表示不同类别。属于同一类别的图像存放在对应的文件夹下。

例如有一个区分猫狗的数据集。训练集在“cat_dog/”目录下。那么“cat_dog/”目录需要包含一个"cat"文件夹和一个"dog"文件夹。 "cat"文件夹中存放所有用于训练的猫的图像;"dog"文件夹中存放所有用于训练的狗的图像。

tst_path类似处理。

要训练GAN网络类似,只需要将trainer_classification.py替换为trainer_gan.py即可。GAN的训练数据要求所有图像都在同一目录下。

强烈建议大家先简要阅读一下wiki

已支持的网络结构

分类网络

没打钩的表示利用仓库中的Network和Block能通过配置文件实现的网络,但对应的config文件还没加入到仓库中。所有打钩的在仓库中都有对应配置文件。其目录在:classification/configs/。

GAN

config目录:gan/configs/

三方依赖:


列出的是已测试版本:

  • Pillow: 9.5.0
  • ptflops:0.7
  • pyyaml:6.0
  • PyTorch: 2.0.0
  • torchmetrics: 0.11.4
  • scipy: 1.5.1
  • opencv-python: 4.5.5

仓库适合谁?


  • 学生。本仓库代码不是按照特定网络结构case-by-case实现的,所以大家能学习到一些基本的关于框架的概念和设计技巧。
  • 想做可视化算法平台的工程师。仓库能大大降低算法的核心代码量,任何新的算法可能只需要定义一个新的Block和配置文件。
  • 算法爱好者。仓库代码给了详细的代码解释,尤其对于复杂的pipeline,能让大家更好理解代码。

更新日志


2023/03/05

更新GAN。为了框架统一,更改了大量之前代码。首先个人精力,测试用例不足,可能会有bug。欢迎指出。

2023/12/03

大更新。使用了更加复杂的设计来统一整体框架。框架的核心概念变为Network和Block。整体框架更加灵活和强大,但配置文件的写法也更加复杂。

To do list

  • 目标检测框架(进行中)
  • Stable Diffusion
  • ???

wiki


请记得仓库有wiki

可用的数据资源


CIFAR10

整理后的图像数据集地址(百度网盘):链接:https://pan.baidu.com/s/1VnHL3cSpQo-exU8m4OpMTA?pwd=bxvw 提取码:bxvw

ImageNet-1k

训练集:https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar

验证集:https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

标签文件:https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz

数据准备参考此处:https://www.yii666.com/blog/339357.html