diff --git a/docs/module_usage/tutorials/cv_modules/multi_object_tracking.md b/docs/module_usage/tutorials/cv_modules/multi_object_tracking.md
new file mode 100644
index 0000000000..359daec79e
--- /dev/null
+++ b/docs/module_usage/tutorials/cv_modules/multi_object_tracking.md
@@ -0,0 +1,190 @@
+简体中文 | [English](multi_object_tracking_en.md)
+
+# Multi-Object Tracking Module Usage Tutorial
+
+## I. Overview
+
+Multi-object tracking (MOT) is an important research task that involves automatically identifying and tracking multiple objects of interest in a video sequence. It requires tracking multiple targets at the same time, obtaining their motion trajectories, and keeping each target's identity consistent over time. The targets can be pedestrians, vehicles, animals, or objects of any other category. Multi-object tracking is of great significance in fields such as video surveillance, autonomous driving, and drone-based monitoring.
+
+## II. Supported Model List
+
+
+👉 Model List Details
+
+|Model|Dataset|MOTA|IDF1|
+|-|-|-|-|
+|FairMOT-DLA-34|MOT-16 Training Set|83.2|83.1|
+|DeepSORT_PP-YOLOE_ResNet|MOT-17 Half Val Set|56.7|64.6|
+|ByteTrack_PP-YOLOE_L|MOT-17 Half Train Set|50.4|59.7|
+
+**The accuracy metrics of the above models are measured on the MOT-16 and MOT-17 datasets.**
+
+
+
+## III. Quick Integration
+> ❗ Before quick integration, please install the PaddleX wheel package first. For details, refer to the [PaddleX Local Installation Tutorial](../../../installation/installation.md).
+
+After installing the wheel package, inference with the multi-object tracking module can be done in just a few lines of code. You can switch freely among the models under this module, and you can also integrate the module's model inference into your own project.
+Before running the following code, please download the [example image](待填充) to your local machine.
+```python
+from paddlex.inference import create_model
+
+model_name = "ByteTrack_PP-YOLOE_L"
+
+model = create_model(model_name)
+output = model.predict("mot.png", batch_size=1)
+
+for res in output:
+ res.print(json_format=False)
+ res.save_to_img("./output/")
+ res.save_to_json("./output/res.json")
+```
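+
+To switch to another model under this module, only the model name passed to `create_model` needs to change; for example (any name from the model list above can be used in the same way):
+
+```python
+model = create_model("FairMOT-DLA-34")  # or "DeepSORT_PP-YOLOE_ResNet"
+```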
+For more information on using the single-model inference APIs of PaddleX, refer to the [PaddleX Single-Model Python Script Usage Instructions](../../instructions/model_python_API.md).
+
+## IV. Custom Development
+If you are pursuing higher accuracy than the existing models provide, you can use PaddleX's custom development capability to develop better multi-object tracking models. Before developing multi-object tracking models with PaddleX, be sure to install the PaddleDetection plugin; the installation process is described in the [PaddleX Local Installation Tutorial](../../../installation/installation.md).
+
+### 4.1 Data Preparation
+Before model training, you need to prepare the dataset for the corresponding task module. PaddleX provides a data validation feature for every module, and **only data that passes validation can be used for model training**. In addition, PaddleX provides a demo dataset for every module, and you can complete subsequent development based on the officially provided demo data. You can also refer to the [PaddleX Multi-Object Tracking Task Module Data Annotation Tutorial](待填充).
+
+#### 4.1.1 Demo Data Download
+You can use the following commands to download the demo dataset to a specified folder:
+
+```bash
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/mot_examples.tar -P ./dataset
+tar -xf ./dataset/mot_examples.tar -C ./dataset/
+```
+#### 4.1.2 Data Validation
+Data validation can be completed with a single command:
+
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=check_dataset \
+ -o Global.dataset_dir=./dataset/mot_examples
+```
+After the above command is executed, PaddleX validates the dataset and collects its basic statistics. If the command runs successfully, the message `Check dataset passed !` is printed in the log. The validation result file is saved to `./output/check_dataset_result.json`, and related outputs are saved in the `./output/check_dataset` directory under the current directory, including visualized example sample images and a sample distribution histogram.
+
+
+👉 Validation Result Details
+
+
+The specific content of the validation result file is:
+
+```json
+{
+ "done_flag": true,
+ "check_pass": true,
+ "attributes": {
+ "num_classes": 1,
+ "train_samples": 5316,
+ "train_sample_paths": [
+ "check_dataset/demo_img/train/0.jpg",
+ "check_dataset/demo_img/train/1.jpg",
+      "check_dataset/demo_img/train/2.jpg"
+ ],
+ "val_samples": 5316,
+ "val_sample_paths": [
+ "check_dataset/demo_img/val/0.jpg",
+ "check_dataset/demo_img/val/1.jpg",
+      "check_dataset/demo_img/val/2.jpg"
+ ]
+ },
+ "analysis": null,
+  "dataset_path": "dataset/mot_examples",
+ "show_type": "image",
+ "dataset_type": "COCODetDataset"
+}
+```
+In the validation result above, `check_pass` being `true` indicates that the dataset format meets the requirements. Explanations of the other fields are as follows:
+
+* `attributes.train_samples`: the number of samples in the training set of this dataset is 5316;
+* `attributes.val_samples`: the number of samples in the validation set of this dataset is 5316;
+* `attributes.train_sample_paths`: a list of relative paths to the visualized images of training set samples in this dataset;
+* `attributes.val_sample_paths`: a list of relative paths to the visualized images of validation set samples in this dataset;
+
+
+### 4.2 Model Training
+A single command is enough to complete model training. Here, the training of ByteTrack_PP-YOLOE_L is used as an example:
+
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=train \
+ -o Global.dataset_dir=./dataset/mot_examples
+```
+The following steps are required:
+
+* Specify the path to the model's `.yaml` configuration file (here, `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model training: `-o Global.mode=train`
+* Specify the training dataset path: `-o Global.dataset_dir`
+
+Other related parameters can be set by modifying the fields under `Global` and `Train` in the `.yaml` configuration file, or adjusted by appending arguments on the command line. For example, to train on the first 2 GPUs: `-o Global.device=gpu:0,1`; to set the number of training epochs to 10: `-o Train.epochs_iters=10`. For more modifiable parameters and their detailed explanations, refer to the configuration file instructions of the corresponding task module in the [PaddleX Common Model Configuration File Parameters](../../instructions/config_parameters_common.md).
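+
+For instance, combining the overrides above with the demo dataset path (the values are purely illustrative), a training run on the first 2 GPUs for 10 epochs would look like this:
+
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+    -o Global.mode=train \
+    -o Global.dataset_dir=./dataset/mot_examples \
+    -o Global.device=gpu:0,1 \
+    -o Train.epochs_iters=10
+```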
+
+
+👉 More Details
+
+
+* During model training, PaddleX automatically saves model weight files, to the `output` directory by default. To specify a different save path, set it via the `-o Global.output` field in the configuration file.
+* PaddleX shields you from the concepts of dynamic-graph weights and static-graph weights. During model training, both dynamic-graph and static-graph weights are produced; for model inference, static-graph weights are selected by default.
+* When training other models, you need to specify the corresponding configuration file. The correspondence between models and configuration files can be found in the [PaddleX Model List (CPU/GPU)](../../../support_list/models_list.md).
+
+After model training is completed, all outputs are saved in the specified output directory (default `./output/`), typically including the following:
+
+* `train_result.json`: the training result record file, which records whether the training task completed normally, as well as metrics of the produced weights, paths of related files, and so on;
+* `train.log`: the training log file, which records changes in model metrics, loss values, etc. during training;
+* `config.yaml`: the training configuration file, which records the hyperparameter configuration of this training run;
+* `.pdparams`, `.pdema`, `.pdopt`, `.pdstates`, `.pdiparams`, `.pdmodel`: model weight related files, including network parameters, optimizer state, EMA weights, static-graph network parameters, static-graph network structure, etc.;
+
+
+### **4.3 Model Evaluation**
+After model training is completed, you can evaluate the specified model weight file on the validation set to verify the model accuracy. With PaddleX, model evaluation can be done with a single command:
+
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=evaluate \
+ -o Global.dataset_dir=./dataset/mot_examples
+```
+Similar to model training, the following steps are required:
+
+* Specify the path to the model's `.yaml` configuration file (here, `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model evaluation: `-o Global.mode=evaluate`
+* Specify the validation dataset path: `-o Global.dataset_dir`
+
+Other related parameters can be set by modifying the fields under `Global` and `Evaluate` in the `.yaml` configuration file. For details, refer to the [PaddleX Common Model Configuration File Parameters](../../instructions/config_parameters_common.md).
+
+
+👉 More Details
+
+
+When evaluating a model, you need to specify the model weight file path. Each configuration file has a built-in default weight save path; if you need to change it, simply append a command-line argument, e.g. `-o Evaluate.weight_path=./output/best_model/model.pdparams`.
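+
+For example, the full evaluation command with this override would look like the following (the weight path is just the default save location mentioned above and is purely illustrative):
+
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+    -o Global.mode=evaluate \
+    -o Global.dataset_dir=./dataset/mot_examples \
+    -o Evaluate.weight_path=./output/best_model/model.pdparams
+```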
+
+After model evaluation is completed, an `evaluate_result.json` file is produced. It records the evaluation results, specifically whether the evaluation task completed normally and the model's evaluation metrics, including MOTA.
+
+
+
+### **4.4 Model Inference**
+After completing model training and evaluation, you can use the trained model weights for inference prediction or for Python integration.
+
+#### 4.4.1 Model Inference
+* To perform inference prediction from the command line, only the following single command is needed. Before running the code below, please download the [example image](待填充) to your local machine.
+```bash
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=predict \
+ -o Predict.model_dir="./output/best_model/inference" \
+ -o Predict.input="mot.png"
+```
+Similar to model training and evaluation, the following steps are required:
+
+* Specify the path to the model's `.yaml` configuration file (here, `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model inference prediction: `-o Global.mode=predict`
+* Specify the model weight path: `-o Predict.model_dir="./output/best_model/inference"`
+* Specify the input data path: `-o Predict.input="..."`
+
+Other related parameters can be set by modifying the fields under `Global` and `Predict` in the `.yaml` configuration file. For details, refer to the [PaddleX Common Model Configuration File Parameters](../../instructions/config_parameters_common.md).
+
+#### 4.4.2 Model Integration
+The model can be directly integrated into PaddleX pipelines, or directly into your own project.
+
+1. **Pipeline Integration**
+
+The PaddleX pipeline that the multi-object tracking module can be integrated into is the [Multi-Object Tracking Pipeline](待填充). Simply replacing the model path completes the model update of the multi-object tracking module in the relevant pipeline. In pipeline integration, you can use high-performance deployment and service-oriented deployment to deploy the model you obtain.
+
+2. **Module Integration**
+
+The weights you produce can be directly integrated into the multi-object tracking module. You can refer to the Python example code in [Quick Integration](#iii-quick-integration); simply replace the model with the path of the model you trained.
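+
+As a minimal sketch of that replacement (the local model directory below is illustrative and assumes the default export location used earlier in this tutorial):
+
+```python
+from paddlex.inference import create_model
+
+# Assumption: pass the directory of your trained/exported inference model
+# in place of the built-in model name used in the Quick Integration example.
+model = create_model("./output/best_model/inference")
+output = model.predict("mot.png", batch_size=1)
+
+for res in output:
+    res.print(json_format=False)
+    res.save_to_img("./output/")
+    res.save_to_json("./output/res.json")
+```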
diff --git a/paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml b/paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml
new file mode 100644
index 0000000000..c5ad95b7bf
--- /dev/null
+++ b/paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml
@@ -0,0 +1,42 @@
+Global:
+ model: ByteTrack_PP-YOLOE_L
+ mode: check_dataset # check_dataset/train/evaluate/predict
+ dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+ device: gpu:0,1,2,3
+ output: "output"
+
+CheckDataset:
+ convert:
+ enable: False
+ src_dataset_type: null
+ split:
+ enable: False
+ train_percent: null
+ val_percent: null
+
+Train:
+ num_classes: 1
+ epochs_iters: 50
+ batch_size: #2
+ learning_rate: #0.0001
+ pretrain_weight_path: null
+ warmup_steps: #100
+ resume_path: null
+ log_interval: 10
+ eval_interval: 1
+
+Evaluate:
+ weight_path: "output/best_model/best_model.pdparams"
+ log_interval: 10
+
+Predict:
+ batch_size: 1
+ model_dir: "output/best_model/inference"
+ input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+ kernel_option:
+ run_mode: paddle
+
+
+
+Export:
+ weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
diff --git a/paddlex/configs/multi_object_tracking/DeepSORT_PP-YOLOE_ResNet.yaml b/paddlex/configs/multi_object_tracking/DeepSORT_PP-YOLOE_ResNet.yaml
new file mode 100644
index 0000000000..fe876eaeaf
--- /dev/null
+++ b/paddlex/configs/multi_object_tracking/DeepSORT_PP-YOLOE_ResNet.yaml
@@ -0,0 +1,42 @@
+Global:
+ model: DeepSORT_PP-YOLOE_ResNet
+ mode: check_dataset # check_dataset/train/evaluate/predict
+ dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+ device: gpu:0,1,2,3
+ output: "output"
+
+CheckDataset:
+ convert:
+ enable: False
+ src_dataset_type: null
+ split:
+ enable: False
+ train_percent: null
+ val_percent: null
+
+Train:
+ num_classes: 1
+ epochs_iters: 50
+ batch_size: 2
+ learning_rate: #0.0001
+ pretrain_weight_path: null
+ warmup_steps: #100
+ resume_path: null
+ log_interval: 10
+ eval_interval: 1
+
+Evaluate:
+ weight_path: "output/best_model/best_model.pdparams"
+ log_interval: 10
+
+Predict:
+ batch_size: 1
+ model_dir: "output/best_model/inference"
+ input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+ kernel_option:
+ run_mode: paddle
+
+
+
+Export:
+ weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
diff --git a/paddlex/configs/multi_object_tracking/FairMOT-DLA-34.yaml b/paddlex/configs/multi_object_tracking/FairMOT-DLA-34.yaml
new file mode 100644
index 0000000000..2876daaff7
--- /dev/null
+++ b/paddlex/configs/multi_object_tracking/FairMOT-DLA-34.yaml
@@ -0,0 +1,42 @@
+Global:
+ model: FairMOT-DLA-34
+ mode: check_dataset # check_dataset/train/evaluate/predict
+ dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+ device: gpu:0,1,2,3
+ output: "output"
+
+CheckDataset:
+ convert:
+ enable: False
+ src_dataset_type: null
+ split:
+ enable: False
+ train_percent: null
+ val_percent: null
+
+Train:
+ num_classes: 1
+ epochs_iters: # 50
+ batch_size: # 2
+ learning_rate: # 0.0001
+ pretrain_weight_path: null
+ warmup_steps: #100
+ resume_path: null
+ log_interval: 10
+ eval_interval: 1
+
+Evaluate:
+ weight_path: "output/best_model/best_model.pdparams"
+ log_interval: 10
+
+Predict:
+ batch_size: 1
+ model_dir: "output/best_model/inference"
+ input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+ kernel_option:
+ run_mode: paddle
+
+
+
+Export:
+ weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
diff --git a/paddlex/modules/__init__.py b/paddlex/modules/__init__.py
index 0369f2d584..318a70960a 100644
--- a/paddlex/modules/__init__.py
+++ b/paddlex/modules/__init__.py
@@ -95,4 +95,11 @@
TSCLSExportor,
)
+from .multi_object_tracking import (
+ MOTDatasetChecker,
+ MOTTrainer,
+ MOTEvaluator,
+ MOTExportor,
+)
+
from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator
diff --git a/paddlex/modules/multi_object_tracking/__init__.py b/paddlex/modules/multi_object_tracking/__init__.py
new file mode 100644
index 0000000000..fd897a48a2
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/__init__.py
@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .trainer import MOTTrainer
+from .dataset_checker import MOTDatasetChecker
+from .evaluator import MOTEvaluator
+from .exportor import MOTExportor
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/__init__.py b/paddlex/modules/multi_object_tracking/dataset_checker/__init__.py
new file mode 100644
index 0000000000..e008a8650c
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/__init__.py
@@ -0,0 +1,115 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+from pathlib import Path
+from collections import defaultdict, Counter
+
+from PIL import Image
+import json
+from pycocotools.coco import COCO
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, convert, split_dataset, deep_analyse
+
+from ..model_list import MODELS
+
+
+class MOTDatasetChecker(BaseDatasetChecker):
+ """Dataset Checker for Multi-Object Tracking Model"""
+
+ entities = MODELS
+ sample_num = 10
+
+ def get_dataset_root(self, dataset_dir: str) -> str:
+ """find the dataset root dir
+
+ Args:
+ dataset_dir (str): the directory that contain dataset.
+
+ Returns:
+ str: the root directory of dataset.
+ """
+ # anno_dirs = list(Path(dataset_dir).glob("**/images"))
+ # assert len(anno_dirs) == 1
+ # dataset_dir = anno_dirs[0].parent.as_posix()
+ return dataset_dir
+
+ def convert_dataset(self, src_dataset_dir: str) -> str:
+ """convert the dataset from other type to specified type
+
+ Args:
+ src_dataset_dir (str): the root directory of dataset.
+
+ Returns:
+ str: the root directory of converted dataset.
+ """
+ return convert(
+ self.check_dataset_config.convert.src_dataset_type, src_dataset_dir
+ )
+
+ def split_dataset(self, src_dataset_dir: str) -> str:
+ """repartition the train and validation dataset
+
+ Args:
+ src_dataset_dir (str): the root directory of dataset.
+
+ Returns:
+ str: the root directory of splited dataset.
+ """
+ return split_dataset(
+ src_dataset_dir,
+ self.check_dataset_config.split.train_percent,
+ self.check_dataset_config.split.val_percent,
+ )
+
+ def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+ """check if the dataset meets the specifications and get dataset summary
+
+ Args:
+ dataset_dir (str): the root directory of dataset.
+ sample_num (int): the number to be sampled.
+ Returns:
+ dict: dataset summary.
+ """
+ return check(dataset_dir, self.output, self.global_config['model'])
+
+ def analyse(self, dataset_dir: str) -> dict:
+ """deep analyse dataset
+
+ Args:
+ dataset_dir (str): the root directory of dataset.
+
+ Returns:
+ dict: the deep analysis results.
+ """
+ return deep_analyse(dataset_dir, self.output, self.global_config['model'])
+
+ def get_show_type(self) -> str:
+ """get the show type of dataset
+
+ Returns:
+ str: show type
+ """
+ return "image"
+
+ def get_dataset_type(self) -> str:
+ """return the dataset type
+
+ Returns:
+ str: dataset type
+ """
+ return "COCODetDataset"
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/__init__.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/__init__.py
new file mode 100644
index 0000000000..7e577f3dba
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/__init__.py
@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .convert_dataset import convert
+from .split_dataset import split_dataset
+from .analyse_dataset import deep_analyse
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/analyse_dataset.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/analyse_dataset.py
new file mode 100644
index 0000000000..6409a3a026
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/analyse_dataset.py
@@ -0,0 +1,83 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import json
+import platform
+from pathlib import Path
+from collections import defaultdict
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from pycocotools.coco import COCO
+
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def deep_analyse(dataset_dir, output, model_name):
+ """class analysis for dataset"""
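+    # FairMOT-DLA-34 uses MOT image-list annotations rather than COCO JSON files,
+    # so no per-class histogram can be generated for it and None is returned.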
+ if model_name == 'FairMOT-DLA-34':
+ return None
+ else:
+ tags = ["train", "val"]
+ all_instances = 0
+ for tag in tags:
+ annotations_path = os.path.abspath(
+ os.path.join(dataset_dir, f"annotations/{tag}_half.json")
+ )
+ labels_cnt = defaultdict(list)
+ coco = COCO(annotations_path)
+ cat_ids = coco.getCatIds()
+ for cat_id in cat_ids:
+ cat_name = coco.loadCats(ids=cat_id)[0]["name"]
+ labels_cnt[cat_name] = labels_cnt[cat_name] + coco.getAnnIds(catIds=cat_id)
+ all_instances += len(labels_cnt[cat_name])
+ if tag == "train":
+ cnts_train = [len(cat_ids) for cat_name, cat_ids in labels_cnt.items()]
+ elif tag == "val":
+ cnts_val = [len(cat_ids) for cat_name, cat_ids in labels_cnt.items()]
+ classes = [cat_name for cat_name, cat_ids in labels_cnt.items()]
+ sorted_id = sorted(
+ range(len(cnts_train)), key=lambda k: cnts_train[k], reverse=True
+ )
+ cnts_train_sorted = sorted(cnts_train, reverse=True)
+ cnts_val_sorted = [cnts_val[index] for index in sorted_id]
+ classes_sorted = [classes[index] for index in sorted_id]
+ x = np.arange(len(classes))
+ width = 0.5
+
+ # bar
+ os_system = platform.system().lower()
+ if os_system == "windows":
+ plt.rcParams["font.sans-serif"] = "FangSong"
+ else:
+ font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH)
+ fig, ax = plt.subplots(figsize=(max(8, int(len(classes) / 5)), 5), dpi=120)
+ ax.bar(x, cnts_train_sorted, width=0.5, label="train")
+ ax.bar(x + width, cnts_val_sorted, width=0.5, label="val")
+ plt.xticks(
+ x + width / 2,
+ classes_sorted,
+ rotation=90,
+ fontproperties=None if os_system == "windows" else font,
+ )
+ ax.set_ylabel("Counts")
+ plt.legend()
+ fig.tight_layout()
+ fig_path = os.path.join(output, "histogram.png")
+ fig.savefig(fig_path)
+ return {"histogram": os.path.join("check_dataset", "histogram.png")}
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/check_dataset.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/check_dataset.py
new file mode 100644
index 0000000000..769e3ea7a7
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/check_dataset.py
@@ -0,0 +1,310 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import os.path as osp
+from collections import defaultdict, Counter
+from pathlib import Path
+import PIL
+from PIL import Image, ImageOps, ImageDraw, ImageFont
+import json
+from pycocotools.coco import COCO
+import numpy as np
+from .....utils.errors import DatasetFileNotFoundError
+from .utils.visualizer import draw_bbox
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+from .....utils import logging
+
+
+def check(dataset_dir, output, model_name):
+ """check dataset"""
+ dataset_dir = osp.abspath(dataset_dir)
+ if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+ raise DatasetFileNotFoundError(file_path=dataset_dir)
+
+ sample_cnts = dict()
+ sample_paths = defaultdict(list)
+ im_sizes = defaultdict(Counter)
+ if model_name == 'FairMOT-DLA-34':
+ num_class = 1
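+        # FairMOT-style data is indexed by image list files under image_lists/;
+        # note that the same mot17.train list is reused below for the val statistics.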
+ tags = {'train':['mot17.train'], 'val':['mot17.train']}
+ for tag in tags.keys():
+ default_image_lists = tags[tag]
+ img_files = {}
+ samp_num = 0
+ for data_name in default_image_lists:
+ list_path = osp.join(dataset_dir, 'image_lists', data_name)
+ with open(list_path, 'r') as file:
+ img_files[data_name] = file.readlines()
+ img_files[data_name] = [
+ os.path.join(dataset_dir, x.strip())
+ for x in img_files[data_name]
+ ]
+ img_files[data_name] = list(
+ filter(lambda x: len(x) > 0, img_files[data_name]))
+
+ image_info = []
+ for data_name in img_files.keys():
+ samp_num += len(img_files[data_name])
+ image_info += img_files[data_name]
+
+ label_files_info = [
+ x.replace('images', 'labels_with_ids').replace(
+ '.png', '.txt').replace('.jpg', '.txt')
+ for x in image_info
+ ]
+
+ sample_num = min(10, samp_num)
+
+ for i in range(sample_num):
+ img_path = image_info[i]
+ label_path = label_files_info[i]
+ labels = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6) # [gt_class, gt_identity, cx, cy, w, h]
+
+ if not osp.exists(img_path):
+ raise DatasetFileNotFoundError(file_path=img_path)
+ img = Image.open(img_path)
+
+ labels[:, [2, 4]] *= img.width
+ labels[:, [3, 5]] *= img.height
+
+ img = ImageOps.exif_transpose(img)
+ vis_im = draw_bbox_with_labels(img, labels)
+ vis_save_dir = osp.join(output, "demo_img", tag)
+                vis_path = osp.join(vis_save_dir, str(i) + '.' + img_path.split('.')[-1])
+ Path(vis_path).parent.mkdir(parents=True, exist_ok=True)
+ vis_im.save(vis_path)
+ sample_path = osp.join(
+ "check_dataset", os.path.relpath(vis_path, output)
+ )
+ sample_paths[tag].append(sample_path)
+
+ attrs = {}
+ attrs["num_classes"] = num_class
+ attrs["train_samples"] = samp_num
+ attrs["train_sample_paths"] = sample_paths['train']
+
+ attrs["val_samples"] = len(img_files['mot17.train'])
+ attrs["val_sample_paths"] = sample_paths['val']
+
+ else:
+ tags = ["train_half", "val_half"]
+ for _, tag in enumerate(tags):
+ file_list = osp.join(dataset_dir, f"annotations/{tag}.json")
+ if not osp.exists(file_list):
+ if tag in ("train_half", "val_half"):
+ # train and val file lists must exist
+ raise DatasetFileNotFoundError(
+ file_path=file_list,
+ solution=f"Ensure that both `train_half.json` and `val_half.json` exist in \
+ {dataset_dir}/annotations",
+ )
+ else:
+ continue
+ else:
+ with open(file_list, "r", encoding="utf-8") as f:
+ jsondata = json.load(f)
+
+ coco = COCO(file_list)
+ num_class = len(coco.getCatIds())
+
+ vis_save_dir = osp.join(output, "demo_img")
+
+ image_info = jsondata["images"]
+ sample_cnts[tag] = len(image_info)
+ sample_num = min(10, len(image_info))
+ img_types = {"train_half": 'train', "val_half": 'val'}
+ for i in range(sample_num):
+ file_name = image_info[i]["file_name"]
+ img_id = image_info[i]["id"]
+ img_type = img_types[tag]
+ img_path = osp.join(dataset_dir, "images", 'train', file_name)
+ if not osp.exists(img_path):
+ raise DatasetFileNotFoundError(file_path=img_path)
+ img = Image.open(img_path)
+ img = ImageOps.exif_transpose(img)
+ vis_im = draw_bbox(img, coco, img_id)
+ vis_path = osp.join(vis_save_dir, file_name)
+ Path(vis_path).parent.mkdir(parents=True, exist_ok=True)
+ vis_im.save(vis_path)
+ sample_path = osp.join(
+ "check_dataset", os.path.relpath(vis_path, output)
+ )
+ sample_paths[tag].append(sample_path)
+
+ attrs = {}
+ attrs["num_classes"] = num_class
+ attrs["train_samples"] = sample_cnts["train_half"]
+ attrs["train_sample_paths"] = sample_paths["train_half"]
+
+ attrs["val_samples"] = sample_cnts["val_half"]
+ attrs["val_sample_paths"] = sample_paths["val_half"]
+ return attrs
+
+def font_colormap(color_index):
+ """
+ Get font color according to the index of colormap
+ """
+ dark = np.array([0x14, 0x0E, 0x35])
+ light = np.array([0xFF, 0xFF, 0xFF])
+ light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+ if color_index in light_indexs:
+ return light.astype("int32")
+ else:
+ return dark.astype("int32")
+
+def colormap(rgb=False):
+ """
+ Get colormap
+
+ The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
+ utils/colormap.py
+ """
+ color_list = np.array(
+ [
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xCC,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x66,
+ 0x00,
+ 0x66,
+ 0xFF,
+ 0xCC,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x4D,
+ 0x00,
+ 0x80,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xB2,
+ 0x00,
+ 0x1A,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0xE5,
+ 0xFF,
+ 0x99,
+ 0x00,
+ 0x33,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x33,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x99,
+ 0xFF,
+ 0xE5,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x1A,
+ 0x00,
+ 0xB2,
+ 0xFF,
+ 0x80,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x4D,
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3))
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list.astype("int32")
+
+def draw_bbox_with_labels(image, label):
+ """
+ Draw bbox on image
+ """
+ font_size = 12
+ font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+
+ image = image.convert("RGB")
+ draw = ImageDraw.Draw(image)
+ image_size = image.size
+ width = int(max(image_size) * 0.005)
+
+ catid2color = {}
+ catid2fontcolor = {}
+ catid_num_dict = {}
+ color_list = colormap(rgb=True)
+ annotations = label
+
+ for ann in annotations:
+ catid = int(ann[1])
+ catid_num_dict[catid] = catid_num_dict.get(catid, 0) + 1
+ for i, (catid, _) in enumerate(
+ sorted(catid_num_dict.items(), key=lambda x: x[1], reverse=True)
+ ):
+ if catid not in catid2color:
+ color_index = i % len(color_list)
+ catid2color[catid] = color_list[color_index]
+ catid2fontcolor[catid] = font_colormap(color_index)
+
+
+ for ann in annotations:
+ catid = int(ann[1])
+ bbox = ann[2:]
+ color = tuple(catid2color[catid])
+ font_color = tuple(catid2fontcolor[catid])
+
+ if len(bbox) == 4:
+ # draw bbox
+ cx, cy, w, h = bbox
+ xmin = cx - 0.5*w
+ ymin = cy - 0.5*h
+ xmax = cx + 0.5*w
+ ymax = cy + 0.5*h
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
+ width=width,
+ fill=color,
+ )
+ else:
+ logging.info("Error: The shape of bbox must be [M, 4]!")
+
+ # draw label
+ label = 'targets_' + str(catid)
+ text = "{}".format(label)
+ if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
+ tw, th = draw.textsize(text, font=font)
+ else:
+ left, top, right, bottom = draw.textbbox((0, 0), text, font)
+ tw, th = right - left, bottom - top
+
+ if ymin < th:
+ draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+ draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+ else:
+ draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+ draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+
+ return image
\ No newline at end of file
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/convert_dataset.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/convert_dataset.py
new file mode 100644
index 0000000000..5c30c1588b
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/convert_dataset.py
@@ -0,0 +1,433 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+import json
+import random
+import xml.etree.ElementTree as ET
+from tqdm import tqdm
+
+from .....utils.file_interface import custom_open, write_json_file
+from .....utils.errors import ConvertFailedError
+from .....utils.logging import info, warning
+
+
+class Indexer(object):
+ """Indexer"""
+
+ def __init__(self):
+ """init indexer"""
+ self._map = {}
+ self.idx = 0
+
+ def get_id(self, key):
+ """get id by key"""
+ if key not in self._map:
+ self.idx += 1
+ self._map[key] = self.idx
+ return self._map[key]
+
+ def get_list(self, key_name):
+ """return list containing key and id"""
+ map_list = []
+ for key in self._map:
+ val = self._map[key]
+ map_list.append({key_name: key, "id": val})
+ return map_list
+
+
+class Extension(object):
+ """Extension"""
+
+ def __init__(self, exts_list):
+ """init extension"""
+ self._exts_list = ["." + ext for ext in exts_list]
+
+ def __iter__(self):
+ """iterator"""
+ return iter(self._exts_list)
+
+ def update(self, ext):
+ """update extension"""
+ self._exts_list.remove(ext)
+ self._exts_list.insert(0, ext)
+
+
+def check_src_dataset(root_dir, dataset_type):
+ """check src dataset format validity"""
+ if dataset_type in ("VOC", "VOCWithUnlabeled"):
+ anno_suffix = ".xml"
+ elif dataset_type in ("LabelMe", "LabelMeWithUnlabeled"):
+ anno_suffix = ".json"
+ else:
+ raise ConvertFailedError(
+            message=f"Dataset format conversion failed! The {dataset_type} dataset format is not supported. Currently only the VOC, LabelMe, VOCWithUnlabeled and LabelMeWithUnlabeled formats are supported."
+ )
+
+    err_msg_prefix = f"Dataset format conversion failed! Please check the format of the dataset to be converted against the `{dataset_type} dataset example` mentioned above."
+
+ anno_map = {}
+ for dst_anno, src_anno in [
+ ("instance_train.json", "train_anno_list.txt"),
+ ("instance_val.json", "val_anno_list.txt"),
+ ]:
+ src_anno_path = os.path.join(root_dir, src_anno)
+ if not os.path.exists(src_anno_path):
+ if dst_anno == "instance_train.json":
+ raise ConvertFailedError(
+                    message=f"{err_msg_prefix} Make sure the file {src_anno_path} exists."
+ )
+ continue
+ with custom_open(src_anno_path, "r") as f:
+ anno_list = f.readlines()
+ for anno_fn in anno_list:
+ anno_fn = anno_fn.strip().split(" ")[-1]
+ anno_path = os.path.join(root_dir, anno_fn)
+ if not os.path.exists(anno_path):
+ raise ConvertFailedError(
+                    message=f'{err_msg_prefix} Make sure the file "{anno_fn}" listed in "{src_anno_path}" exists.'
+ )
+ anno_map[dst_anno] = src_anno_path
+ return anno_map
+
+
+def convert(dataset_type, input_dir):
+ """convert dataset to coco format"""
+ # check format validity
+ anno_map = check_src_dataset(input_dir, dataset_type)
+ (
+ convert_voc_dataset(input_dir, anno_map)
+ if dataset_type in ("VOC", "VOCWithUnlabeled")
+ else convert_labelme_dataset(input_dir, anno_map)
+ )
+
+
+def split_anno_list(root_dir, anno_map):
+ """Split anno list to 80% train and 20% val"""
+
+ train_anno_list = []
+ val_anno_list = []
+ anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
+    shutil.move(anno_map["instance_train.json"], anno_list_bak)
+ with custom_open(anno_list_bak, "r") as f:
+ src_anno = f.readlines()
+ random.shuffle(src_anno)
+ train_anno_list = src_anno[: int(len(src_anno) * 0.8)]
+ val_anno_list = src_anno[int(len(src_anno) * 0.8) :]
+ with custom_open(os.path.join(root_dir, "train_anno_list.txt"), "w") as f:
+ f.writelines(train_anno_list)
+ with custom_open(os.path.join(root_dir, "val_anno_list.txt"), "w") as f:
+ f.writelines(val_anno_list)
+ anno_map["instance_train.json"] = os.path.join(root_dir, "train_anno_list.txt")
+ anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
+    msg = f"{os.path.join(root_dir,'val_anno_list.txt')} does not exist, so the dataset has been split into 80% train and 20% val by default, \
+    and the original 'train_anno_list.txt' has been renamed to 'train_anno_list.txt.bak'."
+
+ warning(msg)
+ return anno_map
+
+
+def convert_labelme_dataset(root_dir, anno_map):
+ """convert dataset labeled by LabelMe to coco format"""
+ label_indexer = Indexer()
+ img_indexer = Indexer()
+
+ annotations_dir = os.path.join(root_dir, "annotations")
+ if not os.path.exists(annotations_dir):
+ os.makedirs(annotations_dir)
+
+ # FIXME(gaotingquan): support lmssl
+ unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
+ if os.path.exists(unlabeled_path):
+ shutil.move(unlabeled_path, os.path.join(annotations_dir, "unlabeled.txt"))
+
+    # val_anno_list does not exist, so split the original dataset
+ if "instance_val.json" not in anno_map:
+ anno_map = split_anno_list(root_dir, anno_map)
+
+ for dst_anno in anno_map:
+ labelme2coco(
+ label_indexer,
+ img_indexer,
+ root_dir,
+ anno_map[dst_anno],
+ os.path.join(annotations_dir, dst_anno),
+ )
+
+
+def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
+ """convert json files generated by LabelMe to coco format and save to files"""
+ with custom_open(anno_path, "r") as f:
+ json_list = f.readlines()
+
+ anno_num = 0
+ anno_list = []
+ image_list = []
+ info(f"Start loading json annotation files from {anno_path} ...")
+ for json_path in tqdm(json_list):
+ json_path = json_path.strip()
+ if not json_path.endswith(".json"):
+ info(f'An illegal json path("{json_path}") found! Has been ignored.')
+ continue
+
+ with custom_open(os.path.join(root_dir, json_path.strip()), "r") as f:
+ labelme_data = json.load(f)
+
+ img_id = img_indexer.get_id(labelme_data["imagePath"])
+ image_list.append(
+ {
+ "id": img_id,
+ "file_name": labelme_data["imagePath"].split("/")[-1],
+ "width": labelme_data["imageWidth"],
+ "height": labelme_data["imageHeight"],
+ }
+ )
+
+ for shape in labelme_data["shapes"]:
+ assert shape["shape_type"] == "rectangle", "Only rectangle are supported."
+ category_id = label_indexer.get_id(shape["label"])
+ (x1, y1), (x2, y2) = shape["points"]
+ x1, x2 = sorted([x1, x2])
+ y1, y2 = sorted([y1, y2])
+ bbox = list(map(float, [x1, y1, x2 - x1, y2 - y1]))
+ anno_num += 1
+ anno_list.append(
+ {
+ "image_id": img_id,
+ "bbox": bbox,
+ "category_id": category_id,
+ "id": anno_num,
+ "iscrowd": 0,
+ "area": bbox[2] * bbox[3],
+ "ignore": 0,
+ }
+ )
+
+ category_list = label_indexer.get_list(key_name="name")
+ data_coco = {
+ "images": image_list,
+ "categories": category_list,
+ "annotations": anno_list,
+ }
+
+ write_json_file(data_coco, save_path)
+    info(f"The converted annotations have been saved to {save_path}.")
+
+
+def convert_voc_dataset(root_dir, anno_map):
+ """convert VOC format dataset to coco format"""
+ label_indexer = Indexer()
+ img_indexer = Indexer()
+
+ annotations_dir = os.path.join(root_dir, "annotations")
+ if not os.path.exists(annotations_dir):
+ os.makedirs(annotations_dir)
+
+ # FIXME(gaotingquan): support lmssl
+ unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
+ if os.path.exists(unlabeled_path):
+ shutil.move(unlabeled_path, os.path.join(annotations_dir, "unlabeled.txt"))
+
+    # val_anno_list does not exist, so split the original dataset
+ if "instance_val.json" not in anno_map:
+ anno_map = split_anno_list(root_dir, anno_map)
+
+ for dst_anno in anno_map:
+ ann_paths = voc_get_label_anno(root_dir, anno_map[dst_anno])
+
+ voc_xmls_to_cocojson(
+ root_dir=root_dir,
+ annotation_paths=ann_paths,
+ label_indexer=label_indexer,
+ img_indexer=img_indexer,
+ output=annotations_dir,
+ output_file=dst_anno,
+ )
+
+
+def voc_get_label_anno(root_dir, anno_path):
+ """
+ Read VOC format annotation file.
+
+ Args:
+        root_dir (str): The directory of the VOC annotation file.
+        anno_path (str): The annotation file path.
+
+    Returns:
+        list: A list of paths to all annotation files listed in `anno_path`; an empty list is
+        returned if `anno_path` does not exist.
+ """
+ if not os.path.exists(anno_path):
+        info(f"The annotation file {anno_path} does not exist and has been ignored!")
+ return []
+ with custom_open(anno_path, "r") as f:
+ ann_ids = f.readlines()
+
+ ann_paths = []
+ info(f"Start loading xml annotation files from {anno_path} ...")
+ for aid in ann_ids:
+ aid = aid.strip().split(" ")[-1]
+ if not aid.endswith(".xml"):
+ info(f'An illegal xml path("{aid}") found! Has been ignored.')
+ continue
+ ann_path = os.path.join(root_dir, aid)
+ ann_paths.append(ann_path)
+
+ return ann_paths
+
+
+def voc_get_image_info(annotation_root, img_indexer):
+ """
+    Get the image info from VOC annotation file.
+
+ Args:
+ annotation_root: The annotation root.
+ img_indexer: indexer to get image id by filename.
+
+ Returns:
+ dict: The image info.
+
+ Raises:
+ AssertionError: When filename cannot be found in 'annotation_root'.
+ """
+ filename = annotation_root.findtext("filename")
+ assert filename is not None, filename
+ img_name = os.path.basename(filename)
+ im_id = img_indexer.get_id(filename)
+
+ size = annotation_root.find("size")
+ width = float(size.findtext("width"))
+ height = float(size.findtext("height"))
+
+ image_info = {"file_name": filename, "height": height, "width": width, "id": im_id}
+ return image_info
+
+
+def voc_get_coco_annotation(obj, label_indexer):
+ """
+ Convert VOC format annotation to COCO format.
+
+ Args:
+ obj: a obj in VOC.
+ label_indexer: indexer to get category id by label name.
+
+ Returns:
+ dict: A dict with the COCO format annotation info.
+
+ Raises:
+ AssertionError: When the width or height of the annotation box is illegal.
+ """
+ label = obj.findtext("name")
+ category_id = label_indexer.get_id(label)
+ bndbox = obj.find("bndbox")
+ xmin = float(bndbox.findtext("xmin"))
+ ymin = float(bndbox.findtext("ymin"))
+ xmax = float(bndbox.findtext("xmax"))
+ ymax = float(bndbox.findtext("ymax"))
+ if xmin > xmax or ymin > ymax:
+ temp = xmin
+ xmin = min(xmin, xmax)
+ xmax = max(temp, xmax)
+ temp = ymin
+ ymin = min(ymin, ymax)
+ ymax = max(temp, ymax)
+
+ o_width = xmax - xmin
+ o_height = ymax - ymin
+ anno = {
+ "area": o_width * o_height,
+ "iscrowd": 0,
+ "bbox": [xmin, ymin, o_width, o_height],
+ "category_id": category_id,
+ "ignore": 0,
+ }
+ return anno
+
+
+def voc_xmls_to_cocojson(
+ root_dir, annotation_paths, label_indexer, img_indexer, output, output_file
+):
+ """
+ Convert VOC format data to COCO format.
+
+ Args:
+ annotation_paths (list): A list of paths to the XML files.
+ label_indexer: indexer to get category id by label name.
+ img_indexer: indexer to get image id by filename.
+ output (str): The directory to save output JSON file.
+ output_file (str): Output JSON file name.
+
+ Returns:
+ None
+ """
+ extension_list = ["jpg", "png", "jpeg", "JPG", "PNG", "JPEG"]
+ suffixs = Extension(extension_list)
+
+ def match(root_dir, prefilename, prexlm_name):
+ """matching extension"""
+ for ext in suffixs:
+ if os.path.exists(os.path.join(root_dir, "images", prefilename + ext)):
+ suffixs.update(ext)
+ return prefilename + ext
+ elif os.path.exists(os.path.join(root_dir, "images", prexlm_name + ext)):
+ suffixs.update(ext)
+ return prexlm_name + ext
+ return None
+
+ output_json_dict = {
+ "images": [],
+ "type": "instances",
+ "annotations": [],
+ "categories": [],
+ }
+
+ bnd_id = 1 # bounding box start id
+ info("Start converting !")
+ for a_path in tqdm(annotation_paths):
+ # Read annotation xml
+ ann_tree = ET.parse(a_path)
+ ann_root = ann_tree.getroot()
+ file_name = ann_root.find("filename")
+ prefile_name = file_name.text.split(".")[0]
+ prexlm_name = os.path.basename(a_path).split(".")[0]
+        # look up the image file by file_name and by the xml file name respectively
+ f_name = match(root_dir, prefile_name, prexlm_name)
+ if f_name is not None:
+ file_name.text = f_name
+ else:
+ prefile_name_set = set({prefile_name, prexlm_name})
+ prefile_name_set = ",".join(prefile_name_set)
+ suffix_set = ",".join(extension_list)
+ images_path = os.path.join(root_dir, "images")
+ info(
+                f"{images_path}/{{{prefile_name_set}}}.{{{suffix_set}}} does not exist, will be skipped."
+ )
+ continue
+ img_info = voc_get_image_info(ann_root, img_indexer)
+ output_json_dict["images"].append(img_info)
+
+ for obj in ann_root.findall("object"):
+            if obj.find("bndbox") is None: # Skip the object without bndbox
+ continue
+ ann = voc_get_coco_annotation(obj=obj, label_indexer=label_indexer)
+ ann.update({"image_id": img_info["id"], "id": bnd_id})
+ output_json_dict["annotations"].append(ann)
+ bnd_id = bnd_id + 1
+
+ output_json_dict["categories"] = label_indexer.get_list(key_name="name")
+ output_file = os.path.join(output, output_file)
+ write_json_file(output_json_dict, output_file)
+    info(f"The converted annotations have been saved to {output_file}.")
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/split_dataset.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/split_dataset.py
new file mode 100644
index 0000000000..b460edd4b0
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/split_dataset.py
@@ -0,0 +1,119 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import shutil
+import random
+import json
+from tqdm import tqdm
+
+from .....utils.file_interface import custom_open, write_json_file
+from .....utils.logging import info
+
+
+def split_dataset(root_dir, train_rate, val_rate):
+ """split dataset"""
+ assert (
+ train_rate + val_rate == 100
+ ), f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
+ assert (
+ train_rate > 0 and val_rate > 0
+ ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
+
+ all_image_info_list = []
+ all_category_dict = {}
+ max_image_id = 0
+ for fn in ["instance_train.json", "instance_val.json"]:
+ anno_path = os.path.join(root_dir, "annotations", fn)
+ if not os.path.exists(anno_path):
+            info(f"The annotation file {anno_path} does not exist and has been ignored!")
+ continue
+ image_info_list, category_list, max_image_id = json2list(
+ anno_path, max_image_id
+ )
+ all_image_info_list.extend(image_info_list)
+
+ for category in category_list:
+ if category["id"] not in all_category_dict:
+ all_category_dict[category["id"]] = category
+
+ total_num = len(all_image_info_list)
+ random.shuffle(all_image_info_list)
+
+ all_category_list = [all_category_dict[k] for k in all_category_dict]
+
+ start = 0
+ for fn, rate in [
+ ("instance_train.json", train_rate),
+ ("instance_val.json", val_rate),
+ ]:
+ end = start + round(total_num * rate / 100)
+ save_path = os.path.join(root_dir, "annotations", fn)
+ if os.path.exists(save_path):
+ bak_path = save_path + ".bak"
+ shutil.move(save_path, bak_path)
+ info(f"The original annotation file {fn} has been backed up to {bak_path}.")
+ assemble_write(all_image_info_list[start:end], all_category_list, save_path)
+ start = end
+ return root_dir
+
+
+def json2list(json_path, base_image_num):
+ """load json as list"""
+ assert os.path.exists(json_path), json_path
+ with custom_open(json_path, "r") as f:
+ data = json.load(f)
+ image_info_dict = {}
+ max_image_id = 0
+ for image_info in data["images"]:
+        # get a globally unique image_id
+ global_image_id = image_info["id"] + base_image_num
+ max_image_id = max(global_image_id, max_image_id)
+ image_info["id"] = global_image_id
+ image_info_dict[global_image_id] = {"img": image_info, "anno": []}
+
+ image_info_dict = {
+ image_info["id"]: {"img": image_info, "anno": []}
+ for image_info in data["images"]
+ }
+ info(f"Start loading annotation file {json_path}...")
+ for anno in tqdm(data["annotations"]):
+ global_image_id = anno["image_id"] + base_image_num
+ anno["image_id"] = global_image_id
+ image_info_dict[global_image_id]["anno"].append(anno)
+ image_info_list = [
+ (image_info_dict[image_info]["img"], image_info_dict[image_info]["anno"])
+ for image_info in image_info_dict
+ ]
+ return image_info_list, data["categories"], max_image_id
+
+
+def assemble_write(image_info_list, category_list, save_path):
+ """assemble coco format and save to file"""
+ coco_data = {"categories": category_list}
+ image_list = [i[0] for i in image_info_list]
+ all_anno_list = []
+ for i in image_info_list:
+ all_anno_list.extend(i[1])
+ anno_list = []
+ for i, anno in enumerate(all_anno_list):
+ anno["id"] = i + 1
+ anno_list.append(anno)
+
+ coco_data["images"] = image_list
+ coco_data["annotations"] = anno_list
+
+ write_json_file(coco_data, save_path)
+    info(f"The split annotations have been saved to {save_path}.")
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/__init__.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/__init__.py
new file mode 100644
index 0000000000..59372f9379
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/__init__.py
@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/visualizer.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/visualizer.py
new file mode 100644
index 0000000000..6029a071c1
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/utils/visualizer.py
@@ -0,0 +1,187 @@
+# -*- coding: UTF-8 -*-
+################################################################################
+#
+# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
+#
+################################################################################
+"""
+Author: PaddlePaddle Authors
+"""
+import os
+import numpy as np
+import json
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+from pycocotools.coco import COCO
+
+from ......utils.fonts import PINGFANG_FONT_FILE_PATH
+from ......utils import logging
+
+
+def colormap(rgb=False):
+ """
+ Get colormap
+
+ The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
+utils/colormap.py
+ """
+ color_list = np.array(
+ [
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xCC,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x66,
+ 0x00,
+ 0x66,
+ 0xFF,
+ 0xCC,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x4D,
+ 0x00,
+ 0x80,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xB2,
+ 0x00,
+ 0x1A,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0xE5,
+ 0xFF,
+ 0x99,
+ 0x00,
+ 0x33,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x33,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x99,
+ 0xFF,
+ 0xE5,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x1A,
+ 0x00,
+ 0xB2,
+ 0xFF,
+ 0x80,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x4D,
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3))
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list.astype("int32")
+
+
+def font_colormap(color_index):
+ """
+ Get font color according to the index of colormap
+ """
+ dark = np.array([0x14, 0x0E, 0x35])
+ light = np.array([0xFF, 0xFF, 0xFF])
+ light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+ if color_index in light_indexs:
+ return light.astype("int32")
+ else:
+ return dark.astype("int32")
+
+
+def draw_bbox(image, coco_info: COCO, img_id):
+ """
+ Draw bbox on image
+ """
+ try:
+ image_info = coco_info.loadImgs(img_id)[0]
+ font_size = int(0.024 * int(image_info["width"])) + 2
+ except:
+ font_size = 12
+ font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+
+ image = image.convert("RGB")
+ draw = ImageDraw.Draw(image)
+ image_size = image.size
+ width = int(max(image_size) * 0.005)
+
+ catid2color = {}
+ catid2fontcolor = {}
+ catid_num_dict = {}
+ color_list = colormap(rgb=True)
+ annotations = coco_info.loadAnns(coco_info.getAnnIds(imgIds=img_id))
+
+ for ann in annotations:
+ catid = ann["category_id"]
+ catid_num_dict[catid] = catid_num_dict.get(catid, 0) + 1
+ for i, (catid, _) in enumerate(
+ sorted(catid_num_dict.items(), key=lambda x: x[1], reverse=True)
+ ):
+ if catid not in catid2color:
+ color_index = i % len(color_list)
+ catid2color[catid] = color_list[color_index]
+ catid2fontcolor[catid] = font_colormap(color_index)
+ for ann in annotations:
+ catid, bbox = ann["category_id"], ann["bbox"]
+ color = tuple(catid2color[catid])
+ font_color = tuple(catid2fontcolor[catid])
+
+ if len(bbox) == 4:
+ # draw bbox
+ xmin, ymin, w, h = bbox
+ xmax = xmin + w
+ ymax = ymin + h
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
+ width=width,
+ fill=color,
+ )
+ elif len(bbox) == 8:
+ x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+ draw.line(
+ [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+ width=width,
+ fill=color,
+ )
+ xmin = min(x1, x2, x3, x4)
+ ymin = min(y1, y2, y3, y4)
+ else:
+ logging.info("Error: The shape of bbox must be [M, 4] or [M, 8]!")
+
+ # draw label
+ label = coco_info.loadCats(catid)[0]["name"]
+ text = "{}".format(label)
+ if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
+ tw, th = draw.textsize(text, font=font)
+ else:
+ left, top, right, bottom = draw.textbbox((0, 0), text, font)
+ tw, th = right - left, bottom - top
+
+ if ymin < th:
+ draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+ draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+ else:
+ draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+ draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+
+ return image
diff --git a/paddlex/modules/multi_object_tracking/evaluator.py b/paddlex/modules/multi_object_tracking/evaluator.py
new file mode 100644
index 0000000000..9aa5f0f7a9
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/evaluator.py
@@ -0,0 +1,41 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from ..base import BaseEvaluator
+from .model_list import MODELS
+
+
+class MOTEvaluator(BaseEvaluator):
+    """Multi-Object Tracking Model Evaluator"""
+
+ entities = MODELS
+
+ def update_config(self):
+ """update evalution config"""
+ if self.eval_config.log_interval:
+ self.pdx_config.update_log_interval(self.eval_config.log_interval)
+ self.pdx_config.update_dataset(self.global_config.dataset_dir, 'MOTImageFolder')
+ self.pdx_config.update_weights(self.eval_config.weight_path)
+
+ def get_eval_kwargs(self) -> dict:
+ """get key-value arguments of model evalution function
+
+ Returns:
+ dict: the arguments of evaluation function.
+ """
+ return {
+ "weight_path": self.eval_config.weight_path,
+ "device": self.get_device(using_device_number=1),
+ }
diff --git a/paddlex/modules/multi_object_tracking/exportor.py b/paddlex/modules/multi_object_tracking/exportor.py
new file mode 100644
index 0000000000..27b357b84b
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/exportor.py
@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class MOTExportor(BaseExportor):
+    """Multi-Object Tracking Model Exportor"""
+
+ entities = MODELS
diff --git a/paddlex/modules/multi_object_tracking/model_list.py b/paddlex/modules/multi_object_tracking/model_list.py
new file mode 100644
index 0000000000..a78dd5232d
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/model_list.py
@@ -0,0 +1,20 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+MODELS = [
+ "ByteTrack_PP-YOLOE_L",
+ "DeepSORT_PP-YOLOE_ResNet",
+ "FairMOT-DLA-34",
+]
diff --git a/paddlex/modules/multi_object_tracking/trainer.py b/paddlex/modules/multi_object_tracking/trainer.py
new file mode 100644
index 0000000000..fec7bb5113
--- /dev/null
+++ b/paddlex/modules/multi_object_tracking/trainer.py
@@ -0,0 +1,86 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from pathlib import Path
+import lazy_paddle as paddle
+
+from ..base import BaseTrainer
+from ...utils.config import AttrDict
+from ...utils import logging
+from .model_list import MODELS
+
+
+class MOTTrainer(BaseTrainer):
+    """Multi-Object Tracking Model Trainer"""
+
+ entities = MODELS
+
+ def _update_dataset(self):
+ """update dataset settings"""
+ self.pdx_config.update_dataset(self.global_config.dataset_dir, 'MOTDataSet')
+
+ def update_config(self):
+ """update training config"""
+ if self.train_config.log_interval:
+ self.pdx_config.update_log_interval(self.train_config.log_interval)
+ if self.train_config.eval_interval:
+ self.pdx_config.update_eval_interval(self.train_config.eval_interval)
+
+ self._update_dataset()
+
+ if self.train_config.num_classes is not None:
+ self.pdx_config.update_num_class(self.train_config.num_classes)
+ if (
+ self.train_config.pretrain_weight_path
+ and self.train_config.pretrain_weight_path != ""
+ ):
+ self.pdx_config.update_pretrained_weights(
+ self.train_config.pretrain_weight_path
+ )
+ if self.train_config.batch_size is not None:
+ self.pdx_config.update_batch_size(self.train_config.batch_size)
+ if self.train_config.learning_rate is not None:
+ self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+ if self.train_config.epochs_iters is not None:
+ self.pdx_config.update_epochs(self.train_config.epochs_iters)
+ epochs_iters = self.train_config.epochs_iters
+ else:
+ epochs_iters = self.pdx_config.get_epochs_iters()
+ if self.global_config.output is not None:
+ self.pdx_config.update_save_dir(self.global_config.output)
+
+ if "PicoDet" in self.global_config.model:
+ assigner_epochs = max(int(epochs_iters / 10), 1)
+ try:
+ self.pdx_config.update_static_assigner_epochs(assigner_epochs)
+ except Exception:
+ logging.info(
+ f"The model({self.global_config.model}) don't support to update_static_assigner_epochs!"
+ )
+
+ def get_train_kwargs(self) -> dict:
+ """get key-value arguments of model training function
+
+ Returns:
+ dict: the arguments of training function.
+ """
+ train_args = {"device": self.get_device()}
+ if (
+ self.train_config.resume_path is not None
+ and self.train_config.resume_path != ""
+ ):
+ train_args["resume_path"] = self.train_config.resume_path
+ train_args["dy2st"] = self.train_config.get("dy2st", False)
+ return train_args
diff --git a/paddlex/repo_apis/PaddleDetection_api/__init__.py b/paddlex/repo_apis/PaddleDetection_api/__init__.py
index 7cd87ab325..c276ea80e1 100644
--- a/paddlex/repo_apis/PaddleDetection_api/__init__.py
+++ b/paddlex/repo_apis/PaddleDetection_api/__init__.py
@@ -15,3 +15,4 @@
from .object_det import DetModel, DetRunner, register
from .instance_seg import InstanceSegModel, InstanceSegRunner, register
+from .mot import MOTModel, MOTRunner, register
\ No newline at end of file
diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/ByteTrack_PP-YOLOE_L.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/ByteTrack_PP-YOLOE_L.yaml
new file mode 100644
index 0000000000..fce121302d
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/configs/ByteTrack_PP-YOLOE_L.yaml
@@ -0,0 +1,206 @@
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+save_dir: output
+print_flops: false
+print_params: false
+log_iter: 20
+snapshot_epoch: 2
+
+# Dataset
+metric: MOT
+num_classes: 1
+
+# Detection Dataset for training
+
+TrainDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/train_half.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+
+TestDataset: !ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ data_root: images/half # half
+ keep_ori_im: True # set as True in DeepSORT and ByteTrack
+
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ keep_ori_im: True # set True if save visualization images or video
+
+
+# Reader
+worker_num: 4
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+TestReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+# add MOTReader for MOT evaluation and inference, note batch_size should be 1 in MOT
+EvalMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+TestMOTReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+# Model
+architecture: ByteTrack
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+ema_black_list: ['proj_conv.weight']
+custom_black_list: ['reduce_mean']
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
+depth_mult: 1.0
+width_mult: 1.0
+
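+# ByteTrack associates high- and low-score detection boxes directly with the JDETracker
+# (use_byte: True below), so no ReID model or reid_weights are required.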
+ByteTrack:
+ detector: YOLOv3 # PPYOLOe version
+ reid: None
+ tracker: JDETracker
+det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+reid_weights: None
+
+YOLOv3:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ post_process: ~
+
+CSPResNet:
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+
+CustomCSPPAN:
+ out_channels: [768, 384, 192]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: -1
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.1
+ nms_threshold: 0.4
+
+# BYTETracker
+JDETracker:
+ use_byte: True
+ match_thres: 0.9
+ conf_thres: 0.2
+ low_conf_thres: 0.1
+ min_box_area: 100
+ vertical_ratio: 1.6 # for pedestrian
+
+# Optimizer
+epoch: 36
+
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - name: CosineDecay
+ max_epochs: 43
+ - name: LinearWarmup
+ start_factor: 0.001
+ epochs: 1
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+
+
+# Exporting the model
+export:
+ post_process: True # Whether post-processing is included in the network when export model.
+ nms: True # Whether NMS is included in the network when export model.
+ benchmark: False # It is used to testing model performance, if set `True`, post-process and NMS will not be exported.
+ fuse_conv_bn: False
diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml
new file mode 100644
index 0000000000..c710cd47e2
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml
@@ -0,0 +1,208 @@
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+save_dir: output
+print_flops: false
+print_params: false
+use_ema: true
+log_iter: 20
+snapshot_epoch: 2
+
+# Dataset
+metric: MOT
+num_classes: 1
+
+# Detection Dataset for training
+TrainDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/train_half.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+
+EvalDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+
+TestDataset: !ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+
+
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ data_root: images/half
+ keep_ori_im: True # set as True in DeepSORT and ByteTrack
+
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ keep_ori_im: True # set True if save visualization images or video
+
+
+# Reader
+worker_num: 4
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 2
+
+TestReader:
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+EvalMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+TestMOTReader:
+ inputs_def:
+ image_shape: [3, 640, 640]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+# Model
+ema_black_list: ['proj_conv.weight']
+
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+depth_mult: 1.0
+width_mult: 1.0
+det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_resnet.pdparams
+
+architecture: DeepSORT
+pretrain_weights: None
+
+
+YOLOv3:
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ post_process: ~
+
+
+CSPResNet:
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+
+
+CustomCSPPAN:
+ out_channels: [768, 384, 192]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+
+PPYOLOEHead:
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: -1
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.4
+ nms_threshold: 0.6
+
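+# DeepSORT crops each detection, embeds it with the ResNet ReID model below, and matches
+# tracks by cosine appearance distance gated by a Kalman-filter motion model.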
+DeepSORT:
+ detector: YOLOv3 # PPYOLOe version
+ reid: ResNetEmbedding
+ tracker: DeepSORTTracker
+
+ResNetEmbedding:
+ model_name: "ResNet50"
+
+DeepSORTTracker:
+ input_size: [64, 192]
+ min_box_area: 0
+ vertical_ratio: -1
+ budget: 100
+ max_age: 70
+ n_init: 3
+ metric_type: cosine
+ matching_threshold: 0.2
+ max_iou_distance: 0.9
+ motion: KalmanFilter
+
+
+# Optimizer
+epoch: 36
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 43
+ - !LinearWarmup
+ start_factor: 0.001
+ epochs: 1
+
+OptimizerBuilder:
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+# Exporting the model
+export:
+ post_process: True # Whether post-processing is included in the network when export model.
+ nms: True # Whether NMS is included in the network when export model.
+ benchmark: False # It is used to testing model performance, if set `True`, post-process and NMS will not be exported.
+ fuse_conv_bn: False
\ No newline at end of file
diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml
new file mode 100644
index 0000000000..f8940d07cc
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml
@@ -0,0 +1,171 @@
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+log_iter: 20
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
+print_params: false
+use_ema: true
+
+# Dataset
+metric: MOT
+num_classes: 1
+
+# for MOT training
+TrainDataset:
+ name: MOTDataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+  image_lists: ['mot17.train']
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
+
+EvalDataset:
+ name: MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ data_root: MOT17/images/train
+ keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
+ data_fields: ['image']
+
+TestDataset:
+ name: ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ anno_path: MOT17/annotations/val_half.json
+
+# for MOT evaluation
+# If you want to change the MOT evaluation dataset, please modify 'data_root'
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ data_root: MOT17/images/train
+ keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
+ data_fields: ['image']
+
+# for MOT video inference
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ keep_ori_im: True # set True if save visualization images or video
+
+# Reader
+worker_num: 4
+TrainReader:
+ inputs_def:
+ image_shape: [3, 608, 1088]
+ sample_transforms:
+ - Decode: {}
+ - RGBReverse: {}
+ - AugmentHSV: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - MOTRandomAffine: {reject_outside: False}
+ - RandomFlip: {}
+ - BboxXYXY2XYWH: {}
+ - NormalizeBox: {}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
+ - RGBReverse: {}
+ - Permute: {}
+ batch_transforms:
+ - Gt2FairMOTTarget: {}
+ batch_size: 6
+ shuffle: True
+ drop_last: True
+ use_shared_memory: True
+
+EvalMOTReader:
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+
+TestMOTReader:
+ inputs_def:
+ image_shape: [3, 608, 1088]
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+
+# Model
+architecture: FairMOT
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
+for_mot: True
+
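+# FairMOT is a one-shot tracker: a single CenterNet network predicts detections and ReID
+# embeddings jointly, and JDETracker associates them across frames.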
+FairMOT:
+ detector: CenterNet
+ reid: FairMOTEmbeddingHead
+ loss: FairMOTLoss
+ tracker: JDETracker
+
+CenterNet:
+ backbone: DLA
+ neck: CenterNetDLAFPN
+ head: CenterNetHead
+ post_process: CenterNetPostProcess
+
+CenterNetDLAFPN:
+ down_ratio: 4
+ last_level: 5
+ out_channel: 0
+ dcn_v2: True
+ with_sge: False
+
+CenterNetHead:
+ head_planes: 256
+ prior_bias: -2.19
+ regress_ltrb: True
+ size_loss: 'L1'
+ loss_weight: {'heatmap': 1.0, 'size': 0.1, 'offset': 1.0, 'iou': 0.0}
+ add_iou: False
+
+FairMOTEmbeddingHead:
+ ch_head: 256
+ ch_emb: 128
+
+CenterNetPostProcess:
+ max_per_img: 500
+ down_ratio: 4
+ regress_ltrb: True
+
+JDETracker:
+ conf_thres: 0.4
+ tracked_thresh: 0.4
+ metric_type: cosine
+ min_box_area: 200
+ vertical_ratio: 1.6 # for pedestrian
+
+# Optimizer
+epoch: 30
+
+LearningRate:
+ base_lr: 0.0001
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [20,]
+ use_warmup: False
+
+OptimizerBuilder:
+ optimizer:
+ type: Adam
+ regularizer: NULL
+
+
+# Exporting the model
+export:
+ post_process: True # Whether post-processing is included in the network when export model.
+ nms: True # Whether NMS is included in the network when export model.
+ benchmark: False # It is used to testing model performance, if set `True`, post-process and NMS will not be exported.
+ fuse_conv_bn: False
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py b/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py
new file mode 100644
index 0000000000..325812d306
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py
@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .model import MOTModel
+from .runner import MOTRunner
+from . import register
+from .official_categories import official_categories
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/config.py b/paddlex/repo_apis/PaddleDetection_api/mot/config.py
new file mode 100644
index 0000000000..622afb5e83
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/config.py
@@ -0,0 +1,505 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ...base import BaseConfig
+from ....utils.misc import abspath
+from ....utils import logging
+from ..config_helper import PPDetConfigMixin
+
+
+class MOTConfig(BaseConfig, PPDetConfigMixin):
+ """MOTConfig"""
+
+ def load(self, config_path: str):
+ """load the config from config file
+
+ Args:
+ config_path (str): the config file path.
+ """
+ dict_ = self.load_config_literally(config_path)
+ self.reset_from_dict(dict_)
+
+ def dump(self, config_path: str):
+ """dump the config
+
+ Args:
+ config_path (str): the path to save dumped config.
+ """
+ self.dump_literal_config(config_path, self._dict)
+
+ def update(self, dict_like_obj: list):
+ """update self from dict
+
+ Args:
+ dict_like_obj (list): the list of pairs that contain key and value.
+ """
+ self.update_from_dict(dict_like_obj, self._dict)
+
+ def update_dataset(
+ self,
+ dataset_path: str,
+ dataset_type: str = None,
+ *,
+ data_fields: list[str] = ['image', 'gt_bbox', 'gt_class', 'is_crowd'],
+ image_dir: str = "images/train",
+ train_anno_path: str = "annotations/train_half.json", # annotations/instance_train.json
+ val_anno_path: str = "annotations/val_half.json",
+ test_anno_path: str = "annotations/val_half.json",
+ ):
+ """update dataset settings
+
+ Args:
+            dataset_path (str): the root path of the dataset.
+            dataset_type (str, optional): the dataset type. Defaults to None.
+            data_fields (list[str], optional): the data fields in the dataset.
+                Defaults to ['image', 'gt_bbox', 'gt_class', 'is_crowd'].
+            image_dir (str, optional): the image directory relative to `dataset_path`. Defaults to "images/train".
+            train_anno_path (str, optional): the train annotation file relative to `dataset_path`.
+                Defaults to "annotations/train_half.json".
+            val_anno_path (str, optional): the validation annotation file relative to `dataset_path`.
+                Defaults to "annotations/val_half.json".
+            test_anno_path (str, optional): the test annotation file relative to `dataset_path`.
+                Defaults to "annotations/val_half.json".
+
+ Raises:
+ ValueError: the `dataset_type` error.
+ """
+ dataset_path = abspath(dataset_path)
+ ds_cfg = {}
+ dataset_names = ['TrainDataset', 'EvalDataset', 'TestDataset', 'EvalMOTDataset', 'TestMOTDataset']
+
+ for data_name in dataset_names:
+            ds_cfg[data_name] = {}
+ ds_cfg[data_name]['dataset_dir'] = dataset_path
+
+ self.update(ds_cfg)
+
+ def _make_dataset_config(
+ self,
+ dataset_root_path: str,
+ data_fields: list[str,] = None,
+ image_dir: str = "images",
+ train_anno_path: str = "annotations/instance_train.json",
+ val_anno_path: str = "annotations/instance_val.json",
+ test_anno_path: str = "annotations/instance_val.json",
+ ) -> dict:
+ """construct the dataset config that meets the format requirements
+
+ Args:
+ dataset_root_path (str): the root directory of dataset.
+ data_fields (list[str,], optional): the data field. Defaults to None.
+ image_dir (str, optional): _description_. Defaults to "images".
+ train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
+ val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
+ test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
+
+ Returns:
+ dict: the dataset config.
+ """
+
+ data_fields = (
+ ["image", "gt_bbox", "gt_class", "is_crowd"]
+ if data_fields is None
+ else data_fields
+ )
+
+ return {
+ "TrainDataset": {
+ "name": "COCODetDataset",
+ "image_dir": image_dir,
+ "anno_path": train_anno_path,
+ "dataset_dir": dataset_root_path,
+ "data_fields": data_fields,
+ },
+ "EvalDataset": {
+ "name": "COCODetDataset",
+ "image_dir": image_dir,
+ "anno_path": val_anno_path,
+ "dataset_dir": dataset_root_path,
+ },
+ "TestDataset": {
+ "name": "ImageFolder",
+ "anno_path": test_anno_path,
+ "dataset_dir": dataset_root_path,
+ },
+ }
+
+ def update_ema(
+ self,
+ use_ema: bool,
+ ema_decay: float = 0.9999,
+ ema_decay_type: str = "exponential",
+ ema_filter_no_grad: bool = True,
+ ):
+ """update EMA setting
+
+ Args:
+ use_ema (bool): whether or not to use EMA
+ ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
+ ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
+ ema_filter_no_grad (bool, optional): whether or not to filter the parameters
+ that been set to stop gradient and are not batch norm parameters. Defaults to True.
+ """
+ self.update(
+ {
+ "use_ema": use_ema,
+ "ema_decay": ema_decay,
+ "ema_decay_type": ema_decay_type,
+ "ema_filter_no_grad": ema_filter_no_grad,
+ }
+ )
+
+ def update_learning_rate(self, learning_rate: float):
+ """update learning rate
+
+ Args:
+ learning_rate (float): the learning rate value to set.
+ """
+ self.LearningRate["base_lr"] = learning_rate
+
+ def update_warmup_steps(self, warmup_steps: int):
+ """update warmup steps
+
+ Args:
+ warmup_steps (int): the warmup steps value to set.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "LinearWarmup":
+ sch["steps"] = warmup_steps
+ sch["epochs_first"] = False
+
+ def update_warmup_enable(self, use_warmup: bool):
+ """whether or not to enable learning rate warmup
+
+ Args:
+ use_warmup (bool): `True` is enable learning rate warmup and `False` is disable.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ if "use_warmup" in sch:
+ sch["use_warmup"] = use_warmup
+
+ def update_cossch_epoch(self, max_epochs: int):
+ """update max epochs of cosine learning rate scheduler
+
+ Args:
+ max_epochs (int): the max epochs value.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "CosineDecay":
+ sch["max_epochs"] = max_epochs
+
+ def update_milestone(self, milestones: list[int]):
+ """update milstone of `PiecewiseDecay` learning scheduler
+
+ Args:
+ milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "PiecewiseDecay":
+ sch["milestones"] = milestones
+
+ def update_batch_size(self, batch_size: int, mode: str = "train"):
+ """update batch size setting
+
+ Args:
+ batch_size (int): the batch size number to set.
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+ Defaults to 'train'.
+
+ Raises:
+ ValueError: mode error.
+ """
+ assert mode in (
+ "train",
+ "eval",
+ "test",
+ ), "mode ({}) should be train, eval or test".format(mode)
+ if mode == "train":
+ self.TrainReader["batch_size"] = batch_size
+ elif mode == "eval":
+ self.EvalReader["batch_size"] = batch_size
+ else:
+ self.TestReader["batch_size"] = batch_size
+
+ def update_epochs(self, epochs: int):
+ """update epochs setting
+
+ Args:
+ epochs (int): the epochs number value to set
+ """
+ self.update({"epoch": epochs})
+
+ def update_device(self, device_type: str):
+ """update device setting
+
+ Args:
+ device (str): the running device to set
+ """
+ if device_type.lower() == "gpu":
+ self["use_gpu"] = True
+ elif device_type.lower() == "xpu":
+ self["use_xpu"] = True
+ self["use_gpu"] = False
+ elif device_type.lower() == "npu":
+ self["use_npu"] = True
+ self["use_gpu"] = False
+ elif device_type.lower() == "mlu":
+ self["use_mlu"] = True
+ self["use_gpu"] = False
+ else:
+ assert device_type.lower() == "cpu"
+ self["use_gpu"] = False
+
+ def update_save_dir(self, save_dir: str):
+ """update directory to save outputs
+
+ Args:
+ save_dir (str): the directory to save outputs.
+ """
+ self["save_dir"] = abspath(save_dir)
+
+ def update_log_interval(self, log_interval: int):
+ """update log interval(steps)
+
+ Args:
+ log_interval (int): the log interval value to set.
+ """
+ self.update({"log_iter": log_interval})
+
+ def update_eval_interval(self, eval_interval: int):
+ """update eval interval(epochs)
+
+ Args:
+ eval_interval (int): the eval interval value to set.
+ """
+ self.update({"snapshot_epoch": eval_interval})
+
+ def update_save_interval(self, save_interval: int):
+ """update eval interval(epochs)
+
+ Args:
+ save_interval (int): the save interval value to set.
+ """
+ self.update({"snapshot_epoch": save_interval})
+
+ def update_log_ranks(self, device):
+ """update log ranks
+
+ Args:
+ device (str): the running device to set
+ """
+ log_ranks = device.split(":")[1]
+ self.update({"log_ranks": log_ranks})
+
+ def update_print_mem_info(self, print_mem_info: bool):
+ """setting print memory info"""
+ assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
+ self.update({"print_mem_info": f"{print_mem_info}"})
+
+    def update_shared_memory(self, shared_memory: bool):
+        """update shared memory setting of train and eval dataloader
+
+        Args:
+            shared_memory (bool): whether or not to use shared memory
+        """
+        assert isinstance(shared_memory, bool), "shared_memory should be a bool"
+        self.update({"TrainReader": {"use_shared_memory": shared_memory}})
+        self.update({"EvalReader": {"use_shared_memory": shared_memory}})
+
+ def update_shuffle(self, shuffle: bool):
+ """update shuffle setting of train and eval dataloader
+
+ Args:
+ shuffle (bool): whether or not to shuffle the data
+ """
+ assert isinstance(shuffle, bool), "shuffle should be a bool"
+ self.update({"TrainReader": {"shuffle": shuffle}})
+ self.update({"EvalReader": {"shuffle": shuffle}})
+
+ def update_weights(self, weight_path: str):
+ """update model weight
+
+ Args:
+ weight_path (str): the path to weight file of model.
+ """
+ self["weights"] = weight_path
+
+ def update_pretrained_weights(self, pretrain_weights: str):
+ """update pretrained weight path
+
+ Args:
+ pretrained_model (str): the local path or url of pretrained weight file to set.
+ """
+ if not pretrain_weights.startswith(
+ "http://"
+ ) and not pretrain_weights.startswith("https://"):
+ pretrain_weights = abspath(pretrain_weights)
+ self["pretrain_weights"] = pretrain_weights
+
+ def update_num_class(self, num_classes: int):
+ """update classes number
+
+ Args:
+ num_classes (int): the classes number value to set.
+ """
+ self["num_classes"] = num_classes
+ if 'CenterNet' in self.model_name:
+ for i in range(len(self['TrainReader']['sample_transforms'])):
+ if 'Gt2CenterNetTarget' in self['TrainReader']['sample_transforms'][i].keys():
+ self['TrainReader']['sample_transforms'][i]['Gt2CenterNetTarget']['num_classes'] = num_classes
+
+
+ def update_random_size(self, randomsize: list[list[int, int]]):
+ """update `target_size` of `BatchRandomResize` op in TestReader
+
+ Args:
+ randomsize (list[list[int, int]]): the list of different size scales.
+ """
+ self.TestReader["batch_transforms"]["BatchRandomResize"][
+ "target_size"
+ ] = randomsize
+
+ def update_num_workers(self, num_workers: int):
+ """update workers number of train and eval dataloader
+
+ Args:
+ num_workers (int): the value of train and eval dataloader workers number to set.
+ """
+ self["worker_num"] = num_workers
+
+ def _recursively_set(self, config: dict, update_dict: dict):
+ """recursively set config
+
+ Args:
+ config (dict): the original config.
+ update_dict (dict): to be updated paramenters and its values
+
+ Example:
+ self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
+ """
+ assert isinstance(update_dict, dict)
+ for key in update_dict:
+ if key not in config:
+ logging.info(f"A new filed of config to set found: {repr(key)}.")
+ config[key] = update_dict[key]
+ elif not isinstance(update_dict[key], dict):
+ config[key] = update_dict[key]
+ else:
+ self._recursively_set(config[key], update_dict[key])
+
+ def update_static_assigner_epochs(self, static_assigner_epochs: int):
+ """update static assigner epochs value
+
+ Args:
+ static_assigner_epochs (int): the value of static assigner epochs
+ """
+ assert "PicoHeadV2" in self
+ self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
+
+ def update_HybridEncoder(self, update_dict: dict):
+ """update the HybridEncoder neck setting
+
+ Args:
+ update_dict (dict): the HybridEncoder setting.
+ """
+ assert "HybridEncoder" in self
+ self._recursively_set(self.HybridEncoder, update_dict)
+
+ def get_epochs_iters(self) -> int:
+ """get epochs
+
+ Returns:
+            int: the epochs value, i.e., `epoch` in config.
+ """
+ return self.epoch
+
+    def get_log_interval(self) -> int:
+        """get log interval(steps)
+
+        Returns:
+            int: the log interval value, i.e., `log_iter` in config.
+        """
+        return self.log_iter
+
+    def get_eval_interval(self) -> int:
+        """get eval interval(epochs)
+
+        Returns:
+            int: the eval interval value, i.e., `snapshot_epoch` in config.
+        """
+        return self.snapshot_epoch
+
+    def get_save_interval(self) -> int:
+        """get save interval(epochs)
+
+        Returns:
+            int: the save interval value, i.e., `snapshot_epoch` in config.
+        """
+        return self.snapshot_epoch
+
+ def get_learning_rate(self) -> float:
+ """get learning rate
+
+ Returns:
+            float: the learning rate value, i.e., `LearningRate.base_lr` in config.
+ """
+ return self.LearningRate["base_lr"]
+
+ def get_batch_size(self, mode="train") -> int:
+ """get batch size
+
+ Args:
+ mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
+ Defaults to 'train'.
+
+ Returns:
+            int: the batch size value of `mode`, i.e., the `batch_size` of the corresponding `TrainReader`/`EvalReader`/`TestReader` in config.
+ """
+ if mode == "train":
+ return self.TrainReader["batch_size"]
+ elif mode == "eval":
+ return self.EvalReader["batch_size"]
+ elif mode == "test":
+ return self.TestReader["batch_size"]
+ else:
+ raise (f"Unknown mode: {repr(mode)}")
+
+ def get_qat_epochs_iters(self) -> int:
+ """get qat epochs
+
+ Returns:
+ int: the epochs value.
+ """
+        return self.epoch // 2
+
+ def get_qat_learning_rate(self) -> float:
+ """get qat learning rate
+
+ Returns:
+ float: the learning rate value.
+ """
+        return self.LearningRate["base_lr"] / 2.0
+
+ def get_train_save_dir(self) -> str:
+ """get the directory to save output
+
+ Returns:
+ str: the directory to save output
+ """
+ return self.save_dir
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/model.py b/paddlex/repo_apis/PaddleDetection_api/mot/model.py
new file mode 100644
index 0000000000..70ee585c54
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/model.py
@@ -0,0 +1,433 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+
+from ...base import BaseModel
+from ...base.utils.arg import CLIArgument
+from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
+from ....utils.misc import abspath
+from ....utils import logging
+
+from .config import MOTConfig
+from .official_categories import official_categories
+
+
+class MOTModel(BaseModel):
+ """Object Detection Model"""
+
+ def train(
+ self,
+ batch_size: int = None,
+ learning_rate: float = None,
+ epochs_iters: int = None,
+ ips: str = None,
+ device: str = "gpu",
+ resume_path: str = None,
+ dy2st: bool = False,
+ amp: str = "OFF",
+ num_workers: int = None,
+ use_vdl: bool = True,
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """train self
+
+ Args:
+ batch_size (int, optional): the train batch size value. Defaults to None.
+ learning_rate (float, optional): the train learning rate value. Defaults to None.
+ epochs_iters (int, optional): the train epochs value. Defaults to None.
+ ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
+ to None. Defaults to None.
+ dy2st (bool, optional): Enable dynamic to static. Defaults to False.
+ amp (str, optional): the amp settings. Defaults to 'OFF'.
+ num_workers (int, optional): the workers number. Defaults to None.
+ use_vdl (bool, optional): enable VisualDL. Defaults to True.
+ save_dir (str, optional): the directory path to save train output. Defaults to None.
+
+ Returns:
+ CompletedProcess: the result of training subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+
+ if batch_size is not None:
+ config.update_batch_size(batch_size, "train")
+ if learning_rate is not None:
+ config.update_learning_rate(learning_rate)
+ if epochs_iters is not None:
+ config.update_epochs(epochs_iters)
+ config.update_cossch_epoch(epochs_iters)
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if resume_path is not None:
+ assert resume_path.endswith(
+ ".pdparams"
+ ), "resume_path should be endswith .pdparam"
+ resume_dir = resume_path[0:-9]
+ cli_args.append(CLIArgument("--resume", resume_dir))
+ if dy2st:
+ cli_args.append(CLIArgument("--to_static"))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+ if save_dir is None:
+ save_dir = abspath(config.get_train_save_dir())
+ else:
+ save_dir = abspath(save_dir)
+ config.update_save_dir(save_dir)
+ if use_vdl:
+ cli_args.append(CLIArgument("--use_vdl", use_vdl))
+ cli_args.append(CLIArgument("--vdl_log_dir", save_dir))
+
+ do_eval = kwargs.pop("do_eval", True)
+ enable_ce = kwargs.pop("enable_ce", None)
+
+ profile = kwargs.pop("profile", None)
+ if profile is not None:
+ cli_args.append(CLIArgument("--profiler_options", profile))
+
+ # Benchmarking mode settings
+ benchmark = kwargs.pop("benchmark", None)
+ if benchmark is not None:
+ envs = benchmark.get("env", None)
+ amp = benchmark.get("amp", None)
+ do_eval = benchmark.get("do_eval", False)
+ num_workers = benchmark.get("num_workers", None)
+ config.update_log_ranks(device)
+ config.update_shuffle(benchmark.get("shuffle", False))
+ config.update_shared_memory(benchmark.get("shared_memory", True))
+ config.update_print_mem_info(benchmark.get("print_mem_info", True))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+ if amp == "O1":
+ # TODO: ppdet only support ampO1
+ cli_args.append(CLIArgument("--amp"))
+ if envs is not None:
+ for env_name, env_value in envs.items():
+ os.environ[env_name] = str(env_value)
+ # set seed to 0 for benchmark mode by enable_ce
+ cli_args.append(CLIArgument("--enable_ce", True))
+ else:
+ if amp != "OFF" and amp is not None:
+ # TODO: consider amp is O1 or O2 in ppdet
+ cli_args.append(CLIArgument("--amp"))
+ if enable_ce is not None:
+ cli_args.append(CLIArgument("--enable_ce", enable_ce))
+
+ # PDX related settings
+ if device_type in ["npu", "xpu", "mlu"]:
+ uniform_output_enabled = False
+ else:
+ uniform_output_enabled = True
+ config.update({"uniform_output_enabled": uniform_output_enabled})
+ config.update({"pdx_model_name": self.name})
+ hpi_config_path = self.model_info.get("hpi_config_path", None)
+ if hpi_config_path:
+ hpi_config_path = hpi_config_path.as_posix()
+ config.update({"hpi_config_path": hpi_config_path})
+
+ self._assert_empty_kwargs(kwargs)
+
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.train(
+ config_path, cli_args, device, ips, save_dir, do_eval=do_eval
+ )
+
+ def evaluate(
+ self,
+ weight_path: str,
+ batch_size: int = None,
+        ips: str = None,
+        device: str = "gpu",
+        amp: str = "OFF",
+ num_workers: int = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """evaluate self using specified weight
+
+ Args:
+ weight_path (str): the path of model weight file to be evaluated.
+ batch_size (int, optional): the batch size value in evaluating. Defaults to None.
+ ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ amp (str, optional): the AMP setting. Defaults to 'OFF'.
+ num_workers (int, optional): the workers number in evaluating. Defaults to None.
+
+ Returns:
+ CompletedProcess: the result of evaluating subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ if batch_size is not None:
+ config.update_batch_size(batch_size, "eval")
+ device_type, device_ids = parse_device(device)
+ if len(device_ids) > 1:
+ raise ValueError(
+ f"multi-{device_type} evaluation is not supported. Please use a single {device_type}."
+ )
+ config.update_device(device_type)
+ if amp != "OFF":
+ # TODO: consider amp is O1 or O2 in ppdet
+ cli_args.append(CLIArgument("--amp"))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+
+ self._assert_empty_kwargs(kwargs)
+
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ cp = self.runner.evaluate(config_path, cli_args, device, ips, config.dict['architecture'])
+ return cp
+
+ def predict(
+ self,
+ input_path: str,
+ weight_path: str,
+ device: str = "gpu",
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """predict using specified weight
+
+ Args:
+ weight_path (str): the path of model weight file used to predict.
+ input_path (str): the path of image file to be predicted.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ save_dir (str, optional): the directory path to save predict output. Defaults to None.
+
+ Returns:
+ CompletedProcess: the result of predicting subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+
+ input_path = abspath(input_path)
+ if os.path.isfile(input_path):
+ cli_args.append(CLIArgument("--infer_img", input_path))
+ else:
+ cli_args.append(CLIArgument("--infer_dir", input_path))
+ if "infer_list" in kwargs:
+ infer_list = abspath(kwargs.get("infer_list"))
+ cli_args.append(CLIArgument("--infer_list", infer_list))
+ if "visualize" in kwargs:
+ cli_args.append(CLIArgument("--visualize", kwargs["visualize"]))
+ if "save_results" in kwargs:
+ cli_args.append(CLIArgument("--save_results", kwargs["save_results"]))
+ if "save_threshold" in kwargs:
+ cli_args.append(CLIArgument("--save_threshold", kwargs["save_threshold"]))
+ if "rtn_im_file" in kwargs:
+ cli_args.append(CLIArgument("--rtn_im_file", kwargs["rtn_im_file"]))
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if save_dir is not None:
+ save_dir = abspath(save_dir)
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+
+ self._assert_empty_kwargs(kwargs)
+
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.predict(config_path, cli_args, device)
+
+ def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
+ """export the dynamic model to static model
+
+ Args:
+ weight_path (str): the model weight file path that used to export.
+ save_dir (str): the directory path to save export output.
+
+ Returns:
+ CompletedProcess: the result of exporting subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+
+ device = kwargs.pop("device", None)
+ if device:
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+
+ if not weight_path.startswith("http"):
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ save_dir = abspath(save_dir)
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+ input_shape = kwargs.pop("input_shape", None)
+ if input_shape is not None:
+ cli_args.append(
+ CLIArgument("-o", f"TestReader.inputs_def.image_shape={input_shape}")
+ )
+
+ use_trt = kwargs.pop("use_trt", None)
+ if use_trt is not None:
+ cli_args.append(CLIArgument("-o", f"trt={bool(use_trt)}"))
+
+ exclude_nms = kwargs.pop("exclude_nms", None)
+ if exclude_nms is not None:
+ cli_args.append(CLIArgument("-o", f"exclude_nms={bool(exclude_nms)}"))
+
+ # PDX related settings
+ config.update({"pdx_model_name": self.name})
+ hpi_config_path = self.model_info.get("hpi_config_path", None)
+ if hpi_config_path:
+ hpi_config_path = hpi_config_path.as_posix()
+ config.update({"hpi_config_path": hpi_config_path})
+
+ if self.name in official_categories.keys():
+ anno_val_file = abspath(
+ os.path.join(
+ config.TestDataset["dataset_dir"], config.TestDataset["anno_path"]
+ )
+ )
+        if anno_val_file is None or (not os.path.isfile(anno_val_file)):
+ categories = official_categories[self.name]
+ temp_anno = {"images": [], "annotations": [], "categories": categories}
+ with self._create_new_val_json_file() as anno_file:
+ json.dump(temp_anno, open(anno_file, "w"))
+ config.update(
+ {"TestDataset": {"dataset_dir": "", "anno_path": anno_file}}
+ )
+ logging.warning(
+ f"{self.name} does not have validate annotations, use {anno_file} default instead."
+ )
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.export(config_path, cli_args, None)
+
+ self._assert_empty_kwargs(kwargs)
+
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.export(config_path, cli_args, None)
+
+ def infer(
+ self,
+ model_dir: str,
+ input_path: str,
+ device: str = "gpu",
+ save_dir: str = None,
+ **kwargs,
+ ):
+ """predict image using infernece model
+
+ Args:
+            model_dir (str): the directory path of inference model files that will be used to predict.
+            input_path (str): the path of the image to be predicted.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ save_dir (str, optional): the directory path to save output. Defaults to None.
+
+ Returns:
+ CompletedProcess: the result of infering subprocess execution.
+ """
+ model_dir = abspath(model_dir)
+ input_path = abspath(input_path)
+ if save_dir is not None:
+ save_dir = abspath(save_dir)
+
+ cli_args = []
+ cli_args.append(CLIArgument("--model_dir", model_dir))
+ cli_args.append(CLIArgument("--image_file", input_path))
+ if save_dir is not None:
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+ device_type, _ = parse_device(device)
+ cli_args.append(CLIArgument("--device", device_type))
+
+ self._assert_empty_kwargs(kwargs)
+
+ return self.runner.infer(cli_args, device)
+
+ def compression(
+ self,
+ weight_path: str,
+ batch_size: int = None,
+ learning_rate: float = None,
+ epochs_iters: int = None,
+ device: str = None,
+ use_vdl: bool = True,
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """compression model
+
+ Args:
+ weight_path (str): the path to weight file of model.
+ batch_size (int, optional): the batch size value of compression training. Defaults to None.
+ learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
+ epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
+ device (str, optional): the device to run compression training. Defaults to 'gpu'.
+ use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
+ save_dir (str, optional): the directory to save output. Defaults to None.
+
+ Returns:
+ CompletedProcess: the result of compression subprocess execution.
+ """
+ weight_path = abspath(weight_path)
+ if save_dir is None:
+ save_dir = self.config["save_dir"]
+ save_dir = abspath(save_dir)
+
+ config = self.config.copy()
+        cps_config = MOTConfig(
+ self.name, config_path=self.model_info["auto_compression_config_path"]
+ )
+ train_cli_args = []
+ export_cli_args = []
+
+ cps_config.update_pretrained_weights(weight_path)
+ if batch_size is not None:
+ cps_config.update_batch_size(batch_size, "train")
+ if learning_rate is not None:
+ cps_config.update_learning_rate(learning_rate)
+ if epochs_iters is not None:
+ cps_config.update_epochs(epochs_iters)
+ if device is not None:
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+        if save_dir is None:
+ save_dir = abspath(config.get_train_save_dir())
+ else:
+ save_dir = abspath(save_dir)
+ cps_config.update_save_dir(save_dir)
+ if use_vdl:
+ train_cli_args.append(CLIArgument("--use_vdl", use_vdl))
+ train_cli_args.append(CLIArgument("--vdl_log_dir", save_dir))
+
+ export_cli_args.append(
+ CLIArgument("--output_dir", os.path.join(save_dir, "export"))
+ )
+
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ # TODO: refactor me
+ cps_config_path = config_path[0:-4] + "_compression" + config_path[-4:]
+ cps_config.dump(cps_config_path)
+ train_cli_args.append(CLIArgument("--slim_config", cps_config_path))
+ export_cli_args.append(CLIArgument("--slim_config", cps_config_path))
+
+ self._assert_empty_kwargs(kwargs)
+
+        return self.runner.compression(
+ config_path, train_cli_args, export_cli_args, device, save_dir
+ )
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py b/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py
new file mode 100644
index 0000000000..7ae51cd0bc
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py
@@ -0,0 +1,14 @@
+official_categories = {
+'PP-YOLOE-L_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-YOLOE-L_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-ShiTuV2_det': [{"name": "mainbody", "id": 0}],
+'PicoDet_layout_1x': [{"name": "Text", "id": 0}, {"name": "Title", "id": 1}, {"name": "List", "id": 2}, {"name": "Table", "id": 3}, {"name": "Figure", "id": 4}],
+'PicoDet-L_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_17cls': [{"name": "paragraph_title", "id": 0}, {"name": "image", "id": 1}, {"name": "text", "id": 2}, {"name": "number", "id": 3}, {"name": "abstract", "id": 4}, {"name": "content", "id": 5}, {"name": "figure_title", "id": 6}, {"name": "formula", "id": 7}, {"name": "table", "id": 8}, {"name": "tabke_title", "id": 9}, {"name":"reference", "id": 10}, {"name": "doc_title", "id": 11}, {"name": "footnote", "id": 12}, {"name": "header", "id": 13}, {"name": "algorithm", "id": 14}, {"name": "footer", "id": 15}, {"name": "seal", "id": 16}],
+'PP-YOLOE_plus_SOD-S': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-largesize-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+}
\ No newline at end of file
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/register.py b/paddlex/repo_apis/PaddleDetection_api/mot/register.py
new file mode 100644
index 0000000000..9e8958c210
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/register.py
@@ -0,0 +1,86 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from pathlib import Path
+
+from ...base.register import register_model_info, register_suite_info
+from .model import MOTModel
+from .config import MOTConfig
+from .runner import MOTRunner
+
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLEDETECTION_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+HPI_CONFIG_DIR = Path(__file__).parent.parent.parent.parent / "utils" / "hpi_configs"
+
+register_suite_info(
+ {
+ "suite_name": "MOT",
+ "model": MOTModel,
+ "runner": MOTRunner,
+ "config": MOTConfig,
+ "runner_root_path": REPO_ROOT_PATH,
+ }
+)
+
+################ Models Using Universal Config ################
+
+register_model_info(
+ {
+ "model_name": "ByteTrack_PP-YOLOE_L",
+ "suite": "MOT",
+ "config_path": osp.join(PDX_CONFIG_DIR, "ByteTrack_PP-YOLOE_L.yaml"),
+ "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+ "supported_dataset_types": ["COCODetDataset"],
+ "supported_train_opts": {
+ "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+ "dy2st": False,
+ "amp": ["OFF"],
+ },
+ "hpi_config_path": HPI_CONFIG_DIR / "ByteTrack_PP-YOLOE_L.yaml",
+ }
+)
+
+register_model_info(
+ {
+ "model_name": "DeepSORT_PP-YOLOE_ResNet",
+ "suite": "MOT",
+ "config_path": osp.join(PDX_CONFIG_DIR, "DeepSORT_PP-YOLOE_ResNet.yaml"),
+ "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+ "supported_dataset_types": ["COCODetDataset"],
+ "supported_train_opts": {
+ "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+ "dy2st": False,
+ "amp": ["OFF"],
+ },
+ "hpi_config_path": HPI_CONFIG_DIR / "DeepSORT_PP-YOLOE_ResNet.yaml",
+ }
+)
+
+register_model_info(
+ {
+ "model_name": "FairMOT-DLA-34",
+ "suite": "MOT",
+ "config_path": osp.join(PDX_CONFIG_DIR, "FairMOT-DLA-34.yaml"),
+ "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+ "supported_dataset_types": ["MOTDataSet"],
+ "supported_train_opts": {
+ "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+ "dy2st": False,
+ "amp": ["OFF"],
+ },
+ "hpi_config_path": HPI_CONFIG_DIR / "FairMOT-DLA-34.yaml",
+ }
+)
\ No newline at end of file
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/runner.py b/paddlex/repo_apis/PaddleDetection_api/mot/runner.py
new file mode 100644
index 0000000000..4e502bdc37
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/runner.py
@@ -0,0 +1,267 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import tempfile
+
+from ...base import BaseRunner
+from ...base.utils.arg import CLIArgument, gather_opts_args
+from ...base.utils.subprocess import CompletedProcess
+
+
+class MOTRunner(BaseRunner):
+ """MOTRunner"""
+
+ def train(
+ self,
+ config_path: str,
+ cli_args: list,
+ device: str,
+ ips: str,
+ save_dir: str,
+ do_eval=True,
+ ) -> CompletedProcess:
+ """train model
+
+ Args:
+ config_path (str): the config file path used to train.
+ cli_args (list): the additional parameters.
+ device (str): the training device.
+ ips (str): the ip addresses of nodes when using distribution.
+ save_dir (str): the directory path to save training output.
+ do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
+
+ Returns:
+ CompletedProcess: the result of training subprocess execution.
+ """
+ args, env = self.distributed(device, ips, log_dir=save_dir)
+ cli_args = self._gather_opts_args(cli_args)
+ cmd = [*args, "tools/train.py"]
+ if do_eval:
+ cmd.append("--eval")
+ cmd.extend(["--config", config_path, *cli_args])
+ return self.run_cmd(
+ cmd,
+ env=env,
+ switch_wdir=True,
+ echo=True,
+ silent=False,
+ capture_output=True,
+ log_path=self._get_train_log_path(save_dir),
+ )
+
+ def evaluate(
+ self, config_path: str, cli_args: list, device: str, ips: str, arch: str
+ ) -> CompletedProcess:
+ """run model evaluating
+
+ Args:
+ config_path (str): the config file path used to evaluate.
+ cli_args (list): the additional parameters.
+ device (str): the evaluating device.
+            ips (str): the ip addresses of nodes when using distribution.
+            arch (str): the model architecture name (e.g. 'DeepSORT'), used to set tracker-specific flags.
+
+ Returns:
+ CompletedProcess: the result of evaluating subprocess execution.
+ """
+ args, env = self.distributed(device, ips)
+ cli_args = self._gather_opts_args(cli_args)
+
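+        # --scaled=True tells eval_mot.py that detector outputs are already scaled to the
+        # original image size (the case for general detectors such as PP-YOLOE here).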
+ cmd = [*args, "tools/eval_mot.py", "--config", config_path, "--scaled=True", *cli_args] # "--training_weights"
+
+ if arch != 'DeepSORT':
+ """
+ deepsort has no need of training, so there is no training_weights trainined by yourself.
+ """
+ cmd.append("--training_weights")
+
+ cp = self.run_cmd(
+ cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
+ )
+ if cp.returncode == 0:
+ metric_dict = _extract_mot_eval_metrics(cp.stdout)
+ cp.metrics = metric_dict
+ return cp
+
+ def predict(
+ self, config_path: str, cli_args: list, device: str
+ ) -> CompletedProcess:
+ """run predicting using dynamic mode
+
+ Args:
+ config_path (str): the config file path used to predict.
+ cli_args (list): the additional parameters.
+ device (str): unused.
+
+ Returns:
+ CompletedProcess: the result of predicting subprocess execution.
+ """
+ # `device` unused
+ cli_args = self._gather_opts_args(cli_args)
+ cmd = [self.python, "tools/infer.py", "-c", config_path, *cli_args]
+ return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+ def export(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
+ """run exporting
+
+ Args:
+ config_path (str): the path of config file used to export.
+ cli_args (list): the additional parameters.
+ device (str): unused.
+
+ Returns:
+ CompletedProcess: the result of exporting subprocess execution.
+ """
+ # `device` unused
+ cli_args = self._gather_opts_args(cli_args)
+ cmd = [
+ self.python,
+ "tools/export_model.py",
+ "--for_fd",
+ "-c",
+ config_path,
+ *cli_args,
+ ]
+
+ cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+ return cp
+
+ def infer(self, cli_args: list, device: str) -> CompletedProcess:
+ """run predicting using inference model
+
+ Args:
+ cli_args (list): the additional parameters.
+ device (str): unused.
+
+ Returns:
+            CompletedProcess: the result of the inference subprocess execution.
+ """
+ # `device` unused
+ cmd = [self.python, "deploy/python/infer.py", "--use_fd_format", *cli_args]
+ return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+ def compression(
+ self,
+ config_path: str,
+ train_cli_args: list,
+ export_cli_args: list,
+ device: str,
+ train_save_dir: str,
+    ) -> tuple:
+ """run compression model
+
+ Args:
+            config_path (str): the path of the config file used for compression training and export.
+ train_cli_args (list): the additional training parameters.
+ export_cli_args (list): the additional exporting parameters.
+ device (str): the running device.
+ train_save_dir (str): the directory path to save output.
+
+ Returns:
+            tuple[CompletedProcess, CompletedProcess]: the results of the training and exporting subprocess executions.
+ """
+ args, env = self.distributed(device, log_dir=train_save_dir)
+ train_cli_args = self._gather_opts_args(train_cli_args)
+ cmd = [*args, "tools/train.py", "-c", config_path, *train_cli_args]
+ cp_train = self.run_cmd(
+ cmd,
+ env=env,
+ switch_wdir=True,
+ echo=True,
+ silent=False,
+ capture_output=True,
+ log_path=self._get_train_log_path(train_save_dir),
+ )
+
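+        # Export the compressed model using the final weights produced by the training step.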
+ cps_weight_path = os.path.join(train_save_dir, "model_final")
+
+ export_cli_args.append(CLIArgument("-o", f"weights={cps_weight_path}"))
+ export_cli_args = self._gather_opts_args(export_cli_args)
+ cmd = [
+ self.python,
+ "tools/export_model.py",
+ "--for_fd",
+ "-c",
+ config_path,
+ *export_cli_args,
+ ]
+ cp_export = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+ return cp_train, cp_export
+
+ def _gather_opts_args(self, args):
+ """_gather_opts_args"""
+ return gather_opts_args(args, "-o")
+
+
+def _extract_eval_metrics(stdout):
+ """extract evaluation metrics from training log
+
+ Args:
+ stdout (str): the training log
+
+ Returns:
+ dict: the training metric
+ """
+ import re
+
+ pattern = r".*\(AP\)\s*@\[\s*IoU=0\.50:0\.95\s*\|\s*area=\s*all\s\|\smaxDets=\s*\d+\s\]\s*=\s*[0-1]?\.[0-9]{3}$"
+ key = "AP"
+
+ metric_dict = dict()
+ pattern = re.compile(pattern)
+ # TODO: Use lazy version to make it more efficient
+ lines = stdout.splitlines()
+ metric_dict[key] = 0
+ for line in lines:
+ match = pattern.search(line)
+ if match:
+ metric_dict[key] = float(match.group(0)[-5:])
+
+ return metric_dict
+
+
+def _extract_mot_eval_metrics(stdout):
+ """extract evaluation metrics from training log
+
+ Args:
+ stdout (str): the training log
+
+ Returns:
+ dict: the training metric
+ """
+ import re
+
+ pattern = r"OVERALL"
+ key = "MOTA"
+
+ metric_dict = dict()
+ pattern = re.compile(pattern)
+ # TODO: Use lazy version to make it more efficient
+ lines = stdout.splitlines()
+ metric_dict[key] = 0
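+    # The evaluation summary contains an "OVERALL" row with the aggregated metrics; the
+    # value taken below (fifth-from-last column, assumed to be MOTA reported as a
+    # percentage) is divided by 100 to normalize it to [0, 1].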
+ for line in lines:
+ match = pattern.search(line)
+ if match:
+ outs = line.split(' ')
+ outs = [i for i in outs if i != ""]
+ metric_dict[key] = float(outs[-5].split('%')[0])/100
+
+ return metric_dict
diff --git a/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml b/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml
@@ -0,0 +1,65 @@
+Hpi:
+ backend_config:
+ onnx_runtime:
+ cpu_num_threads: 8
+ openvino:
+ cpu_num_threads: 8
+ paddle_infer:
+ cpu_num_threads: 8
+ enable_log_info: false
+ paddle_tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ enable_log_info: false
+ max_batch_size: null
+ tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ max_batch_size: null
+ selected_backends:
+ cpu: onnx_runtime
+ gpu: tensorrt
+ supported_backends:
+ cpu:
+ - paddle_infer
+ - openvino
+ - onnx_runtime
+ gpu:
+ - paddle_infer
+ - paddle_tensorrt
+ - onnx_runtime
+ - tensorrt
diff --git a/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml b/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml
@@ -0,0 +1,65 @@
+Hpi:
+ backend_config:
+ onnx_runtime:
+ cpu_num_threads: 8
+ openvino:
+ cpu_num_threads: 8
+ paddle_infer:
+ cpu_num_threads: 8
+ enable_log_info: false
+ paddle_tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ enable_log_info: false
+ max_batch_size: null
+ tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ max_batch_size: null
+ selected_backends:
+ cpu: onnx_runtime
+ gpu: tensorrt
+ supported_backends:
+ cpu:
+ - paddle_infer
+ - openvino
+ - onnx_runtime
+ gpu:
+ - paddle_infer
+ - paddle_tensorrt
+ - onnx_runtime
+ - tensorrt
diff --git a/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml b/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml
@@ -0,0 +1,65 @@
+Hpi:
+ backend_config:
+ onnx_runtime:
+ cpu_num_threads: 8
+ openvino:
+ cpu_num_threads: 8
+ paddle_infer:
+ cpu_num_threads: 8
+ enable_log_info: false
+ paddle_tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ enable_log_info: false
+ max_batch_size: null
+ tensorrt:
+ dynamic_shapes:
+ im_shape:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ image:
+ - []
+ - []
+ - []
+ scale_factor:
+ - - 1
+ - 2
+ - - 1
+ - 2
+ - - 1
+ - 2
+ max_batch_size: null
+ selected_backends:
+ cpu: onnx_runtime
+ gpu: tensorrt
+ supported_backends:
+ cpu:
+ - paddle_infer
+ - openvino
+ - onnx_runtime
+ gpu:
+ - paddle_infer
+ - paddle_tensorrt
+ - onnx_runtime
+ - tensorrt