+简体中文 | [English](multi_object_tracking_en.md)
+# Multi-Object Tracking Module Tutorial
+## I. Overview
+## II. Supported Model List
+ 👉 Model list details
+|Model|Dataset|MOTA|IDF1|
+|-|-|-|-|
+|FairMOT-DLA-34|MOT-16 Training Set|83.2|83.1|
+|DeepSORT_PP-YOLOE_ResNet|MOT-17 half Val Set|56.7|64.6|
+|ByteTrack_PP-YOLOE_L|MOT-17 half train Set|50.4|59.7|
+**The accuracy metrics above were measured on the MOT-16 and MOT-17 datasets.**
+## III. Quick Integration
+> ❗ Before quick integration, install the PaddleX wheel package first. For details, see the [PaddleX Local Installation Tutorial](../../../installation/installation.md).
+from paddlex.inference import create_model
+model_name = "ByteTrack_PP-YOLOE_L"
+model = create_model(model_name)
+output = model.predict("mot.png", batch_size=1)
+for res in output:
+ res.print(json_format=False)
+ res.save_to_img("./output/")
+ res.save_to_json("./output/res.json")
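+For more on using the PaddleX single-model inference API, see the [PaddleX Single-Model Python Script Usage Instructions](../../instructions/model_python_API.md).
+## IV. Custom Development
+### 4.1 Data Preparation
+Before training a model, you need to prepare a dataset for the corresponding task module. PaddleX provides data validation for every module, and **only data that passes validation can be used for model training**. PaddleX also provides a demo dataset for every module, so you can complete the subsequent development with the official demo data. See also the [PaddleX Multi-Object Tracking Task Module Data Annotation Tutorial](待填充).
+#### 4.1.1 Demo Data Download
+You can download the demo dataset to a specified folder with the following commands: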
+cd /path/to/paddlex
+wget https://paddle-model-ecology.bj.bcebos.com/paddlex/data/mot_examples.tar -P ./dataset
+tar -xf ./dataset/mot_examples.tar -C ./dataset/
+#### 4.1.2 Data Validation
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=check_dataset \
+ -o Global.dataset_dir=./dataset/mot_examples
+After running the command above, PaddleX validates the dataset and collects its basic statistics. If the command succeeds, `Check dataset passed !` is printed in the log. The validation result is saved to `./output/check_dataset_result.json`, and related outputs are saved in the `./output/check_dataset` directory under the current directory, including visualized sample images and a sample distribution histogram.
+ 👉 Validation result details (click to expand)
+{
+  "done_flag": true,
+  "check_pass": true,
+  "attributes": {
+    "num_classes": 1,
+    "train_samples": 5316,
+    "train_sample_paths": [
+      "check_dataset/demo_img/train/0.jpg",
+      "check_dataset/demo_img/train/1.jpg",
+      "check_dataset/demo_img/train/2.jpg"
+    ],
+    "val_samples": 5316,
+    "val_sample_paths": [
+      "check_dataset/demo_img/val/0.jpg",
+      "check_dataset/demo_img/val/1.jpg",
+      "check_dataset/demo_img/val/2.jpg"
+    ]
+  },
+  "analysis": null,
+  "dataset_path": "dataset/mot_datasetss",
+  "show_type": "image",
+  "dataset_type": "COCODetDataset"
+}
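+In the validation result above, `check_pass` being `true` means the dataset format meets the requirements. The other fields are explained below:
+* `attributes.train_samples`: the number of training samples in this dataset is 5316;
+* `attributes.val_samples`: the number of validation samples in this dataset is 5316;
+* `attributes.train_sample_paths`: a list of relative paths to the visualized training sample images;
+* `attributes.val_sample_paths`: a list of relative paths to the visualized validation sample images.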
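As a rough illustration of how these fields might be consumed, the sketch below summarizes a result dictionary shaped like the one above. The `summarize_check_result` helper and the inline sample are hypothetical and not part of PaddleX; only the field names come from the validation output. In practice the dictionary would presumably be loaded from `./output/check_dataset_result.json` with `json.load`.

```python
import json


def summarize_check_result(result: dict) -> str:
    """Return a one-line summary of a check_dataset result dictionary."""
    if not result.get("check_pass", False):
        return "dataset check failed"
    attrs = result["attributes"]
    return (
        f"classes={attrs['num_classes']}, "
        f"train={attrs['train_samples']}, "
        f"val={attrs['val_samples']}"
    )


# Inline sample mirroring the fields shown above (a stand-in for the real file).
sample = json.loads(
    '{"done_flag": true, "check_pass": true, '
    '"attributes": {"num_classes": 1, "train_samples": 5316, "val_samples": 5316}}'
)
print(summarize_check_result(sample))
```

Checking `check_pass` before reading `attributes` keeps the helper safe to call on a failed validation run, where the attribute block may be incomplete.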
+### 4.2 Model Training
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=train \
+ -o Global.dataset_dir=./dataset/mot_examples
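+* Specify the path to the model's `.yaml` configuration file (here `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model training: `-o Global.mode=train`
+* Specify the training dataset path: `-o Global.dataset_dir`
+Other parameters can be set by editing the fields under `Global` and `Train` in the `.yaml` configuration file, or adjusted by appending arguments on the command line. For example, to train on the first two GPUs: `-o Global.device=gpu:0,1`; to set the number of training epochs to 10: `-o Train.epochs_iters=10`. For more configurable parameters and their detailed explanations, see the configuration file documentation for the corresponding task module: [PaddleX Common Model Configuration File Parameters](../../instructions/config_parameters_common.md).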
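To make the relationship between `-o` overrides and the `.yaml` fields concrete, here is a minimal sketch of how a dotted key such as `Train.epochs_iters=10` could map onto a nested configuration. The `apply_override` function is a hypothetical illustration, not PaddleX's actual parser, and it keeps every value as a string for simplicity.

```python
def apply_override(config: dict, override: str) -> dict:
    """Apply one 'Section.key=value' override to a nested dict in place."""
    dotted_key, _, value = override.partition("=")
    *parents, leaf = dotted_key.split(".")
    node = config
    for name in parents:
        # Create missing sections on the way down.
        node = node.setdefault(name, {})
    node[leaf] = value  # values stay strings in this sketch
    return config


config = {
    "Global": {"mode": "check_dataset", "device": "gpu:0,1,2,3"},
    "Train": {"epochs_iters": "50"},
}
for item in ["Global.mode=train", "Global.device=gpu:0,1", "Train.epochs_iters=10"]:
    apply_override(config, item)
print(config)
```

Splitting on the first `=` only means values that themselves contain commas or colons, such as `gpu:0,1`, pass through unchanged.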
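+ 👉 More details (click to expand)
+* During model training, PaddleX automatically saves the model weight files, to `output` by default. To use a different save path, set the `Global.output` field in the configuration file or pass `-o Global.output` on the command line.
+* PaddleX hides the distinction between dynamic-graph and static-graph weights. Training produces both kinds of weights, and inference uses the static-graph weights by default.
+* To train other models, specify the corresponding configuration file. The mapping between models and configuration files is listed in the [PaddleX Model List (CPU/GPU)](../../../support_list/models_list.md).
+* `train_result.json`: the training result record, which states whether the training task completed normally and records the metrics of the produced weights, related file paths, and so on;
+* `train.log`: the training log, which records changes in model metrics, loss, and so on during training;
+* `config.yaml`: the training configuration file, which records the hyperparameters used for this run;
+* `.pdparams`, `.pdema`, `.pdopt.pdstate`, `.pdiparams`, `.pdmodel`: model weight files, including network parameters, optimizer state, EMA, static-graph network parameters, and static-graph network structure.
+### **4.3 Model Evaluation**
+After training, you can evaluate a specified model weight file on the validation set to verify the model's accuracy. With PaddleX, evaluation takes a single command: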
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=evaluate \
+ -o Global.dataset_dir=./dataset/mot_examples
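+* Specify the path to the model's `.yaml` configuration file (here `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model evaluation: `-o Global.mode=evaluate`
+* Specify the validation dataset path: `-o Global.dataset_dir`
+ 👉 More details (click to expand)
+Model evaluation requires the path to a model weight file. Every configuration file has a built-in default weight path; to change it, append a command-line argument such as `-o Evaluate.weight_path=./output/best_model/model.pdparams`.
+After evaluation completes, an `evaluate_result.json` file is produced that records the evaluation result: whether the evaluation task completed normally, and the model's evaluation metrics, including MOTA.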
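For context, MOTA (Multiple Object Tracking Accuracy) is conventionally defined as 1 − (FN + FP + IDSW) / GT, where FN, FP, and IDSW are the numbers of misses, false positives, and identity switches summed over all frames, and GT is the total number of ground-truth objects. The sketch below computes it from such counts. The function name and the sample counts are made up for illustration and do not come from PaddleX or from the models listed above.

```python
def mota(false_negatives: int, false_positives: int, id_switches: int, num_gt: int) -> float:
    """Compute MOTA from error counts accumulated over all frames."""
    if num_gt <= 0:
        raise ValueError("num_gt must be positive")
    return 1.0 - (false_negatives + false_positives + id_switches) / num_gt


# Hypothetical counts: 1000 ground-truth boxes and 150 errors in total.
score = mota(false_negatives=100, false_positives=40, id_switches=10, num_gt=1000)
print(f"{score:.3f}")
```

Because the three error counts are summed without a cap, MOTA can be negative when a tracker makes more errors than there are ground-truth objects.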
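+### **4.4 Model Inference**
+#### 4.4.1 Model Inference
+* To run inference from the command line, a single command is enough. Before running it, download the [example image](待填充) to your local machine.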
+python main.py -c paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml \
+ -o Global.mode=predict \
+ -o Predict.model_dir="./output/best_model/inference" \
+ -o Predict.input="mot.png"
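+* Specify the path to the model's `.yaml` configuration file (here `ByteTrack_PP-YOLOE_L.yaml`)
+* Set the mode to model inference: `-o Global.mode=predict`
+* Specify the model weight path: `-o Predict.model_dir="./output/best_model/inference"`
+* Specify the input data path: `-o Predict.input="..."`
+#### 4.4.2 Model Integration
+The model can be integrated directly into a PaddleX pipeline or into your own project.
+The weights you produce can be integrated directly into the multi-object tracking module. Refer to the Python example code in the quick integration section (the third section above) and replace the model name with the path to your trained model.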
diff --git a/paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml b/paddlex/configs/multi_object_tracking/ByteTrack_PP-YOLOE_L.yaml
+Global:
+  model: ByteTrack_PP-YOLOE_L
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+Train:
+  num_classes: 1
+  epochs_iters: 50
+  batch_size: #2
+  learning_rate: #0.0001
+  pretrain_weight_path: null
+  warmup_steps: #100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+  kernel_option:
+    run_mode: paddle
+Export:
+  weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
+Global:
+  model: DeepSORT_PP-YOLOE_ResNet
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+Train:
+  num_classes: 1
+  epochs_iters: 50
+  batch_size: 2
+  learning_rate: #0.0001
+  pretrain_weight_path: null
+  warmup_steps: #100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+  kernel_option:
+    run_mode: paddle
+Export:
+  weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
+Global:
+  model: FairMOT-DLA-34
+  mode: check_dataset # check_dataset/train/evaluate/predict
+  dataset_dir: "/paddle/dataset/paddlex/det/det_coco_examples"
+  device: gpu:0,1,2,3
+  output: "output"
+CheckDataset:
+  convert:
+    enable: False
+    src_dataset_type: null
+  split:
+    enable: False
+    train_percent: null
+    val_percent: null
+Train:
+  num_classes: 1
+  epochs_iters: # 50
+  batch_size: # 2
+  learning_rate: # 0.0001
+  pretrain_weight_path: null
+  warmup_steps: #100
+  resume_path: null
+  log_interval: 10
+  eval_interval: 1
+Evaluate:
+  weight_path: "output/best_model/best_model.pdparams"
+  log_interval: 10
+Predict:
+  batch_size: 1
+  model_dir: "output/best_model/inference"
+  input: "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/general_object_detection_002.png"
+  kernel_option:
+    run_mode: paddle
+Export:
+  weight_path: # https://paddledet.bj.bcebos.com/models/detr_r50_1x_coco.pdparams
+from .multi_object_tracking import (
+    MOTDatasetChecker,
+    MOTTrainer,
+    MOTEvaluator,
+    MOTExportor,
+)
from .ts_forecast import TSFCDatasetChecker, TSFCTrainer, TSFCEvaluator
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .trainer import MOTTrainer
+from .dataset_checker import MOTDatasetChecker
+from .evaluator import MOTEvaluator
+from .exportor import MOTExportor
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import os.path as osp
+from pathlib import Path
+from collections import defaultdict, Counter
+from PIL import Image
+import json
+from pycocotools.coco import COCO
+from ...base import BaseDatasetChecker
+from .dataset_src import check, convert, split_dataset, deep_analyse
+from ..model_list import MODELS
+class MOTDatasetChecker(BaseDatasetChecker):
+ """Dataset Checker for Multi-Object Tracking Model"""
+ entities = MODELS
+ sample_num = 10
+ def get_dataset_root(self, dataset_dir: str) -> str:
+ """find the dataset root dir
+ Args:
+ dataset_dir (str): the directory that contain dataset.
+ Returns:
+ str: the root directory of dataset.
+ """
+ # anno_dirs = list(Path(dataset_dir).glob("**/images"))
+ # assert len(anno_dirs) == 1
+ # dataset_dir = anno_dirs[0].parent.as_posix()
+ return dataset_dir
+ def convert_dataset(self, src_dataset_dir: str) -> str:
+ """convert the dataset from other type to specified type
+ Args:
+ src_dataset_dir (str): the root directory of dataset.
+ Returns:
+ str: the root directory of converted dataset.
+ """
+ return convert(
+ self.check_dataset_config.convert.src_dataset_type, src_dataset_dir
+ )
+ def split_dataset(self, src_dataset_dir: str) -> str:
+ """repartition the train and validation dataset
+ Args:
+ src_dataset_dir (str): the root directory of dataset.
+ Returns:
+            str: the root directory of the split dataset.
+ """
+ return split_dataset(
+ src_dataset_dir,
+ self.check_dataset_config.split.train_percent,
+ self.check_dataset_config.split.val_percent,
+ )
+ def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+ """check if the dataset meets the specifications and get dataset summary
+ Args:
+ dataset_dir (str): the root directory of dataset.
+ sample_num (int): the number to be sampled.
+ Returns:
+ dict: dataset summary.
+ """
+ return check(dataset_dir, self.output, self.global_config['model'])
+ def analyse(self, dataset_dir: str) -> dict:
+ """deep analyse dataset
+ Args:
+ dataset_dir (str): the root directory of dataset.
+ Returns:
+ dict: the deep analysis results.
+ """
+ return deep_analyse(dataset_dir, self.output, self.global_config['model'])
+ def get_show_type(self) -> str:
+ """get the show type of dataset
+ Returns:
+ str: show type
+ """
+ return "image"
+ def get_dataset_type(self) -> str:
+ """return the dataset type
+ Returns:
+ str: dataset type
+ """
+ return "COCODetDataset"
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .check_dataset import check
+from .convert_dataset import convert
+from .split_dataset import split_dataset
+from .analyse_dataset import deep_analyse
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+import platform
+from pathlib import Path
+from collections import defaultdict
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+from pycocotools.coco import COCO
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+def deep_analyse(dataset_dir, output, model_name):
+ """class analysis for dataset"""
+ if model_name == 'FairMOT-DLA-34':
+ return None
+ else:
+ tags = ["train", "val"]
+ all_instances = 0
+ for tag in tags:
+ annotations_path = os.path.abspath(
+ os.path.join(dataset_dir, f"annotations/{tag}_half.json")
+ )
+ labels_cnt = defaultdict(list)
+ coco = COCO(annotations_path)
+ cat_ids = coco.getCatIds()
+ for cat_id in cat_ids:
+ cat_name = coco.loadCats(ids=cat_id)[0]["name"]
+ labels_cnt[cat_name] = labels_cnt[cat_name] + coco.getAnnIds(catIds=cat_id)
+ all_instances += len(labels_cnt[cat_name])
+ if tag == "train":
+ cnts_train = [len(cat_ids) for cat_name, cat_ids in labels_cnt.items()]
+ elif tag == "val":
+ cnts_val = [len(cat_ids) for cat_name, cat_ids in labels_cnt.items()]
+ classes = [cat_name for cat_name, cat_ids in labels_cnt.items()]
+ sorted_id = sorted(
+ range(len(cnts_train)), key=lambda k: cnts_train[k], reverse=True
+ )
+ cnts_train_sorted = sorted(cnts_train, reverse=True)
+ cnts_val_sorted = [cnts_val[index] for index in sorted_id]
+ classes_sorted = [classes[index] for index in sorted_id]
+ x = np.arange(len(classes))
+ width = 0.5
+ # bar
+ os_system = platform.system().lower()
+ if os_system == "windows":
+ plt.rcParams["font.sans-serif"] = "FangSong"
+ else:
+ font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH)
+ fig, ax = plt.subplots(figsize=(max(8, int(len(classes) / 5)), 5), dpi=120)
+ ax.bar(x, cnts_train_sorted, width=0.5, label="train")
+ ax.bar(x + width, cnts_val_sorted, width=0.5, label="val")
+ plt.xticks(
+ x + width / 2,
+ classes_sorted,
+ rotation=90,
+ fontproperties=None if os_system == "windows" else font,
+ )
+ ax.set_ylabel("Counts")
+ plt.legend()
+ fig.tight_layout()
+ fig_path = os.path.join(output, "histogram.png")
+ fig.savefig(fig_path)
+ return {"histogram": os.path.join("check_dataset", "histogram.png")}
diff --git a/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/check_dataset.py b/paddlex/modules/multi_object_tracking/dataset_checker/dataset_src/check_dataset.py
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+import os.path as osp
+from collections import defaultdict, Counter
+from pathlib import Path
+import PIL
+from PIL import Image, ImageOps, ImageDraw, ImageFont
+import json
+from pycocotools.coco import COCO
+import numpy as np
+from .....utils.errors import DatasetFileNotFoundError
+from .utils.visualizer import draw_bbox
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+def check(dataset_dir, output, model_name):
+ """check dataset"""
+ dataset_dir = osp.abspath(dataset_dir)
+ if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+ raise DatasetFileNotFoundError(file_path=dataset_dir)
+ sample_cnts = dict()
+ sample_paths = defaultdict(list)
+ im_sizes = defaultdict(Counter)
+ if model_name == 'FairMOT-DLA-34':
+ num_class = 1
+ tags = {'train':['mot17.train'], 'val':['mot17.train']}
+ for tag in tags.keys():
+ default_image_lists = tags[tag]
+ img_files = {}
+ samp_num = 0
+ for data_name in default_image_lists:
+ list_path = osp.join(dataset_dir, 'image_lists', data_name)
+ with open(list_path, 'r') as file:
+ img_files[data_name] = file.readlines()
+ img_files[data_name] = [
+ os.path.join(dataset_dir, x.strip())
+ for x in img_files[data_name]
+ ]
+ img_files[data_name] = list(
+ filter(lambda x: len(x) > 0, img_files[data_name]))
+ image_info = []
+ for data_name in img_files.keys():
+ samp_num += len(img_files[data_name])
+ image_info += img_files[data_name]
+ label_files_info = [
+ x.replace('images', 'labels_with_ids').replace(
+ '.png', '.txt').replace('.jpg', '.txt')
+ for x in image_info
+ ]
+ sample_num = min(10, samp_num)
+            for i in range(sample_num):
+                img_path = image_info[i]
+                label_path = label_files_info[i]
+                # each row: [gt_class, gt_identity, cx, cy, w, h], box values normalized by image size
+                labels = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
+                if not osp.exists(img_path):
+                    raise DatasetFileNotFoundError(file_path=img_path)
+                img = Image.open(img_path)
+                # scale the normalized boxes to pixel units
+                labels[:, [2, 4]] *= img.width
+                labels[:, [3, 5]] *= img.height
+                img = ImageOps.exif_transpose(img)
+                vis_im = draw_bbox_with_labels(img, labels)
+                vis_save_dir = osp.join(output, "demo_img", tag)
+                # use splitext so that dots elsewhere in the path do not break the extension
+                vis_path = osp.join(vis_save_dir, str(i) + osp.splitext(img_path)[1])
+                Path(vis_path).parent.mkdir(parents=True, exist_ok=True)
+                vis_im.save(vis_path)
+                sample_path = osp.join(
+                    "check_dataset", os.path.relpath(vis_path, output)
+                )
+                sample_paths[tag].append(sample_path)
+ attrs = {}
+ attrs["num_classes"] = num_class
+ attrs["train_samples"] = samp_num
+ attrs["train_sample_paths"] = sample_paths['train']
+ attrs["val_samples"] = len(img_files['mot17.train'])
+ attrs["val_sample_paths"] = sample_paths['val']
+ else:
+ tags = ["train_half", "val_half"]
+ for _, tag in enumerate(tags):
+ file_list = osp.join(dataset_dir, f"annotations/{tag}.json")
+ if not osp.exists(file_list):
+ if tag in ("train_half", "val_half"):
+ # train and val file lists must exist
+ raise DatasetFileNotFoundError(
+ file_path=file_list,
+ solution=f"Ensure that both `train_half.json` and `val_half.json` exist in \
+ {dataset_dir}/annotations",
+ )
+ else:
+ continue
+ else:
+ with open(file_list, "r", encoding="utf-8") as f:
+ jsondata = json.load(f)
+ coco = COCO(file_list)
+ num_class = len(coco.getCatIds())
+ vis_save_dir = osp.join(output, "demo_img")
+ image_info = jsondata["images"]
+ sample_cnts[tag] = len(image_info)
+ sample_num = min(10, len(image_info))
+ img_types = {"train_half": 'train', "val_half": 'val'}
+ for i in range(sample_num):
+ file_name = image_info[i]["file_name"]
+ img_id = image_info[i]["id"]
+ img_type = img_types[tag]
+ img_path = osp.join(dataset_dir, "images", 'train', file_name)
+ if not osp.exists(img_path):
+ raise DatasetFileNotFoundError(file_path=img_path)
+ img = Image.open(img_path)
+ img = ImageOps.exif_transpose(img)
+ vis_im = draw_bbox(img, coco, img_id)
+ vis_path = osp.join(vis_save_dir, file_name)
+ Path(vis_path).parent.mkdir(parents=True, exist_ok=True)
+ vis_im.save(vis_path)
+ sample_path = osp.join(
+ "check_dataset", os.path.relpath(vis_path, output)
+ )
+ sample_paths[tag].append(sample_path)
+ attrs = {}
+ attrs["num_classes"] = num_class
+ attrs["train_samples"] = sample_cnts["train_half"]
+ attrs["train_sample_paths"] = sample_paths["train_half"]
+ attrs["val_samples"] = sample_cnts["val_half"]
+ attrs["val_sample_paths"] = sample_paths["val_half"]
+ return attrs
+def font_colormap(color_index):
+ """
+ Get font color according to the index of colormap
+ """
+ dark = np.array([0x14, 0x0E, 0x35])
+ light = np.array([0xFF, 0xFF, 0xFF])
+ light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+ if color_index in light_indexs:
+ return light.astype("int32")
+ else:
+ return dark.astype("int32")
+def colormap(rgb=False):
+ """
+ Get colormap
+ The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
+ utils/colormap.py
+ """
+ color_list = np.array(
+ [
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xCC,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x66,
+ 0x00,
+ 0x66,
+ 0xFF,
+ 0xCC,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x4D,
+ 0x00,
+ 0x80,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xB2,
+ 0x00,
+ 0x1A,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0xE5,
+ 0xFF,
+ 0x99,
+ 0x00,
+ 0x33,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x33,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x99,
+ 0xFF,
+ 0xE5,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x1A,
+ 0x00,
+ 0xB2,
+ 0xFF,
+ 0x80,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x4D,
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3))
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list.astype("int32")
+def draw_bbox_with_labels(image, label):
+    """
+    Draw bboxes and their identity labels on an image.
+    Each row of `label` is [gt_class, gt_identity, cx, cy, w, h] in pixel units.
+    """
+    font_size = 12
+    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+    image = image.convert("RGB")
+    draw = ImageDraw.Draw(image)
+    image_size = image.size
+    width = int(max(image_size) * 0.005)
+    catid2color = {}
+    catid2fontcolor = {}
+    catid_num_dict = {}
+    color_list = colormap(rgb=True)
+    annotations = label
+    # count boxes per identity so that colors are assigned by frequency
+    for ann in annotations:
+        catid = int(ann[1])
+        catid_num_dict[catid] = catid_num_dict.get(catid, 0) + 1
+    for i, (catid, _) in enumerate(
+        sorted(catid_num_dict.items(), key=lambda x: x[1], reverse=True)
+    ):
+        if catid not in catid2color:
+            color_index = i % len(color_list)
+            catid2color[catid] = color_list[color_index]
+            catid2fontcolor[catid] = font_colormap(color_index)
+    for ann in annotations:
+        catid = int(ann[1])
+        bbox = ann[2:]
+        color = tuple(catid2color[catid])
+        font_color = tuple(catid2fontcolor[catid])
+        if len(bbox) == 4:
+            # convert center format to corner coordinates, then draw the bbox
+            cx, cy, w, h = bbox
+            xmin = cx - 0.5 * w
+            ymin = cy - 0.5 * h
+            xmax = cx + 0.5 * w
+            ymax = cy + 0.5 * h
+            draw.line(
+                [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
+                width=width,
+                fill=color,
+            )
+        else:
+            logging.info("Error: The shape of bbox must be [M, 4]!")
+            # without corner coordinates there is nothing to anchor the label on
+            continue
+        # draw label
+        text = f"targets_{catid}"
+        # ImageDraw.textsize was removed in Pillow 10.0.0, so use textbbox from that version on
+        if tuple(map(int, PIL.__version__.split("."))) < (10, 0, 0):
+            tw, th = draw.textsize(text, font=font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), text, font)
+            tw, th = right - left, bottom - top
+        if ymin < th:
+            draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+            draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+        else:
+            draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+            draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+    return image
\ No newline at end of file
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+import json
+import random
+import xml.etree.ElementTree as ET
+from tqdm import tqdm
+from .....utils.file_interface import custom_open, write_json_file
+from .....utils.errors import ConvertFailedError
+from .....utils.logging import info, warning
+class Indexer(object):
+ """Indexer"""
+ def __init__(self):
+ """init indexer"""
+ self._map = {}
+ self.idx = 0
+ def get_id(self, key):
+ """get id by key"""
+ if key not in self._map:
+ self.idx += 1
+ self._map[key] = self.idx
+ return self._map[key]
+ def get_list(self, key_name):
+ """return list containing key and id"""
+ map_list = []
+ for key in self._map:
+ val = self._map[key]
+ map_list.append({key_name: key, "id": val})
+ return map_list
+class Extension(object):
+ """Extension"""
+ def __init__(self, exts_list):
+ """init extension"""
+ self._exts_list = ["." + ext for ext in exts_list]
+ def __iter__(self):
+ """iterator"""
+ return iter(self._exts_list)
+ def update(self, ext):
+ """update extension"""
+ self._exts_list.remove(ext)
+ self._exts_list.insert(0, ext)
+def check_src_dataset(root_dir, dataset_type):
+ """check src dataset format validity"""
+ if dataset_type in ("VOC", "VOCWithUnlabeled"):
+ anno_suffix = ".xml"
+ elif dataset_type in ("LabelMe", "LabelMeWithUnlabeled"):
+ anno_suffix = ".json"
+ else:
+ raise ConvertFailedError(
+ message=f"数据格式转换失败!不支持{dataset_type}格式数据集。当前仅支持 VOC、LabelMe 和 VOCWithUnlabeled、LabelMeWithUnlabeled 格式。"
+ )
+ err_msg_prefix = f"数据格式转换失败!请参考上述`{dataset_type}格式数据集示例`检查待转换数据集格式。"
+ anno_map = {}
+ for dst_anno, src_anno in [
+ ("instance_train.json", "train_anno_list.txt"),
+ ("instance_val.json", "val_anno_list.txt"),
+ ]:
+ src_anno_path = os.path.join(root_dir, src_anno)
+ if not os.path.exists(src_anno_path):
+ if dst_anno == "instance_train.json":
+ raise ConvertFailedError(
+ message=f"{err_msg_prefix}保证{src_anno_path}文件存在。"
+ )
+ continue
+ with custom_open(src_anno_path, "r") as f:
+ anno_list = f.readlines()
+ for anno_fn in anno_list:
+ anno_fn = anno_fn.strip().split(" ")[-1]
+ anno_path = os.path.join(root_dir, anno_fn)
+ if not os.path.exists(anno_path):
+ raise ConvertFailedError(
+ message=f'{err_msg_prefix}保证"{src_anno_path}"中的"{anno_fn}"文件存在。'
+ )
+ anno_map[dst_anno] = src_anno_path
+ return anno_map
+def convert(dataset_type, input_dir):
+    """convert dataset to coco format"""
+    # check format validity
+    anno_map = check_src_dataset(input_dir, dataset_type)
+    if dataset_type in ("VOC", "VOCWithUnlabeled"):
+        convert_voc_dataset(input_dir, anno_map)
+    else:
+        convert_labelme_dataset(input_dir, anno_map)
+def split_anno_list(root_dir, anno_map):
+    """Split anno list to 80% train and 20% val"""
+    anno_list_bak = os.path.join(root_dir, "train_anno_list.txt.bak")
+    shutil.move(anno_map["instance_train.json"], anno_list_bak)
+    with custom_open(anno_list_bak, "r") as f:
+        src_anno = f.readlines()
+    random.shuffle(src_anno)
+    split_idx = int(len(src_anno) * 0.8)
+    train_anno_list = src_anno[:split_idx]
+    val_anno_list = src_anno[split_idx:]
+    with custom_open(os.path.join(root_dir, "train_anno_list.txt"), "w") as f:
+        f.writelines(train_anno_list)
+    with custom_open(os.path.join(root_dir, "val_anno_list.txt"), "w") as f:
+        f.writelines(val_anno_list)
+    anno_map["instance_train.json"] = os.path.join(root_dir, "train_anno_list.txt")
+    anno_map["instance_val.json"] = os.path.join(root_dir, "val_anno_list.txt")
+    # implicit concatenation keeps source indentation out of the message
+    msg = (
+        f"{os.path.join(root_dir, 'val_anno_list.txt')}不存在,数据集已默认按照80%训练集,20%验证集划分,"
+        "且将原始'train_anno_list.txt'重命名为'train_anno_list.txt.bak'."
+    )
+    warning(msg)
+    return anno_map
+def convert_labelme_dataset(root_dir, anno_map):
+ """convert dataset labeled by LabelMe to coco format"""
+ label_indexer = Indexer()
+ img_indexer = Indexer()
+ annotations_dir = os.path.join(root_dir, "annotations")
+ if not os.path.exists(annotations_dir):
+ os.makedirs(annotations_dir)
+ # FIXME(gaotingquan): support lmssl
+ unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
+ if os.path.exists(unlabeled_path):
+ shutil.move(unlabeled_path, os.path.join(annotations_dir, "unlabeled.txt"))
+    # val_anno_list does not exist, so split the original dataset
+ if "instance_val.json" not in anno_map:
+ anno_map = split_anno_list(root_dir, anno_map)
+ for dst_anno in anno_map:
+ labelme2coco(
+ label_indexer,
+ img_indexer,
+ root_dir,
+ anno_map[dst_anno],
+ os.path.join(annotations_dir, dst_anno),
+ )
+def labelme2coco(label_indexer, img_indexer, root_dir, anno_path, save_path):
+ """convert json files generated by LabelMe to coco format and save to files"""
+ with custom_open(anno_path, "r") as f:
+ json_list = f.readlines()
+ anno_num = 0
+ anno_list = []
+ image_list = []
+ info(f"Start loading json annotation files from {anno_path} ...")
+ for json_path in tqdm(json_list):
+ json_path = json_path.strip()
+ if not json_path.endswith(".json"):
+ info(f'An illegal json path("{json_path}") found! Has been ignored.')
+ continue
+ with custom_open(os.path.join(root_dir, json_path.strip()), "r") as f:
+ labelme_data = json.load(f)
+ img_id = img_indexer.get_id(labelme_data["imagePath"])
+ image_list.append(
+ {
+ "id": img_id,
+ "file_name": labelme_data["imagePath"].split("/")[-1],
+ "width": labelme_data["imageWidth"],
+ "height": labelme_data["imageHeight"],
+ }
+ )
+ for shape in labelme_data["shapes"]:
+ assert shape["shape_type"] == "rectangle", "Only rectangle are supported."
+ category_id = label_indexer.get_id(shape["label"])
+ (x1, y1), (x2, y2) = shape["points"]
+ x1, x2 = sorted([x1, x2])
+ y1, y2 = sorted([y1, y2])
+ bbox = list(map(float, [x1, y1, x2 - x1, y2 - y1]))
+ anno_num += 1
+ anno_list.append(
+ {
+ "image_id": img_id,
+ "bbox": bbox,
+ "category_id": category_id,
+ "id": anno_num,
+ "iscrowd": 0,
+ "area": bbox[2] * bbox[3],
+ "ignore": 0,
+ }
+ )
+ category_list = label_indexer.get_list(key_name="name")
+ data_coco = {
+ "images": image_list,
+ "categories": category_list,
+ "annotations": anno_list,
+ }
+ write_json_file(data_coco, save_path)
+    info(f"The converted annotations have been saved to {save_path}.")
+def convert_voc_dataset(root_dir, anno_map):
+ """convert VOC format dataset to coco format"""
+ label_indexer = Indexer()
+ img_indexer = Indexer()
+ annotations_dir = os.path.join(root_dir, "annotations")
+ if not os.path.exists(annotations_dir):
+ os.makedirs(annotations_dir)
+ # FIXME(gaotingquan): support lmssl
+ unlabeled_path = os.path.join(root_dir, "unlabeled.txt")
+ if os.path.exists(unlabeled_path):
+ shutil.move(unlabeled_path, os.path.join(annotations_dir, "unlabeled.txt"))
+    # val_anno_list does not exist, so split the original dataset
+ if "instance_val.json" not in anno_map:
+ anno_map = split_anno_list(root_dir, anno_map)
+ for dst_anno in anno_map:
+ ann_paths = voc_get_label_anno(root_dir, anno_map[dst_anno])
+ voc_xmls_to_cocojson(
+ root_dir=root_dir,
+ annotation_paths=ann_paths,
+ label_indexer=label_indexer,
+ img_indexer=img_indexer,
+ output=annotations_dir,
+ output_file=dst_anno,
+ )
+def voc_get_label_anno(root_dir, anno_path):
+    """
+    Read VOC format annotation file.
+    Args:
+        root_dir (str): The directory of VOC annotation file.
+        anno_path (str): The annotation file path.
+    Returns:
+        list: The paths to all annotation files listed in `anno_path`.
+    """
+    if not os.path.exists(anno_path):
+        info(f"The annotation file {anno_path} does not exist, has been ignored!")
+        return []
+    with custom_open(anno_path, "r") as f:
+        ann_ids = f.readlines()
+    ann_paths = []
+    info(f"Start loading xml annotation files from {anno_path} ...")
+    for aid in ann_ids:
+        aid = aid.strip().split(" ")[-1]
+        if not aid.endswith(".xml"):
+            info(f'An illegal xml path("{aid}") found! Has been ignored.')
+            continue
+        ann_path = os.path.join(root_dir, aid)
+        ann_paths.append(ann_path)
+    return ann_paths
+def voc_get_image_info(annotation_root, img_indexer):
+ """
+    Get the image info from VOC annotation file.
+ Args:
+ annotation_root: The annotation root.
+ img_indexer: indexer to get image id by filename.
+ Returns:
+ dict: The image info.
+ Raises:
+ AssertionError: When filename cannot be found in 'annotation_root'.
+ """
+ filename = annotation_root.findtext("filename")
+ assert filename is not None, filename
+ img_name = os.path.basename(filename)
+ im_id = img_indexer.get_id(filename)
+ size = annotation_root.find("size")
+ width = float(size.findtext("width"))
+ height = float(size.findtext("height"))
+ image_info = {"file_name": filename, "height": height, "width": width, "id": im_id}
+ return image_info
+def voc_get_coco_annotation(obj, label_indexer):
+ """
+ Convert VOC format annotation to COCO format.
+ Args:
+ obj: an `<object>` element from a VOC annotation.
+ label_indexer: indexer to get category id by label name.
+ Returns:
+ dict: A dict with the COCO format annotation info.
+ Note:
+ If the box corners are swapped, they are reordered so that the width and height are non-negative.
+ """
+ label = obj.findtext("name")
+ category_id = label_indexer.get_id(label)
+ bndbox = obj.find("bndbox")
+ xmin = float(bndbox.findtext("xmin"))
+ ymin = float(bndbox.findtext("ymin"))
+ xmax = float(bndbox.findtext("xmax"))
+ ymax = float(bndbox.findtext("ymax"))
+ if xmin > xmax or ymin > ymax:
+ temp = xmin
+ xmin = min(xmin, xmax)
+ xmax = max(temp, xmax)
+ temp = ymin
+ ymin = min(ymin, ymax)
+ ymax = max(temp, ymax)
+ o_width = xmax - xmin
+ o_height = ymax - ymin
+ anno = {
+ "area": o_width * o_height,
+ "iscrowd": 0,
+ "bbox": [xmin, ymin, o_width, o_height],
+ "category_id": category_id,
+ "ignore": 0,
+ }
+ return anno
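The box conversion done by `voc_get_coco_annotation` can be illustrated in isolation. The following is a minimal sketch, assuming a VOC `<object>` element parsed with `xml.etree.ElementTree`; the helper name `voc_box_to_coco` and the sample XML are illustrative and not part of PaddleX.

```python
import xml.etree.ElementTree as ET


def voc_box_to_coco(obj):
    """Return a COCO-style [x, y, width, height] box from a VOC <object> element."""
    bndbox = obj.find("bndbox")
    xmin = float(bndbox.findtext("xmin"))
    ymin = float(bndbox.findtext("ymin"))
    xmax = float(bndbox.findtext("xmax"))
    ymax = float(bndbox.findtext("ymax"))
    # Tolerate swapped corners by ordering each coordinate pair.
    xmin, xmax = min(xmin, xmax), max(xmin, xmax)
    ymin, ymax = min(ymin, ymax), max(ymin, ymax)
    return [xmin, ymin, xmax - xmin, ymax - ymin]


# Hypothetical annotation whose xmin and xmax are swapped.
sample = ET.fromstring(
    "<object><name>pedestrian</name>"
    "<bndbox><xmin>120</xmin><ymin>40</ymin><xmax>20</xmax><ymax>90</ymax></bndbox>"
    "</object>"
)
coco_box = voc_box_to_coco(sample)
print(coco_box)
```

VOC stores two corners, while COCO stores the top-left corner plus width and height, which is why the function subtracts after ordering the corners.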
+def voc_xmls_to_cocojson(
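+ root_dir, annotation_paths, label_indexer, img_indexer, output, output_file
+):
+ """
+ Convert VOC format data to COCO format.
+ Args:
+ root_dir (str): The dataset root directory; images are looked up under its `images` subdirectory.
+ annotation_paths (list): A list of paths to the XML files.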
+ label_indexer: indexer to get category id by label name.
+ img_indexer: indexer to get image id by filename.
+ output (str): The directory to save output JSON file.
+ output_file (str): Output JSON file name.
+ Returns:
+ None
+ """
+ extension_list = ["jpg", "png", "jpeg", "JPG", "PNG", "JPEG"]
+ suffixs = Extension(extension_list)
+ def match(root_dir, prefilename, prexlm_name):
+ """matching extension"""
+ for ext in suffixs:
+ if os.path.exists(os.path.join(root_dir, "images", prefilename + ext)):
+ suffixs.update(ext)
+ return prefilename + ext
+ elif os.path.exists(os.path.join(root_dir, "images", prexlm_name + ext)):
+ suffixs.update(ext)
+ return prexlm_name + ext
+ return None
+ output_json_dict = {
+ "images": [],
+ "type": "instances",
+ "annotations": [],
+ "categories": [],
+ }
+ bnd_id = 1 # bounding box start id
+ info("Start converting !")
+ for a_path in tqdm(annotation_paths):
+ # Read annotation xml
+ ann_tree = ET.parse(a_path)
+ ann_root = ann_tree.getroot()
+ file_name = ann_root.find("filename")
+ prefile_name = file_name.text.split(".")[0]
+ prexlm_name = os.path.basename(a_path).split(".")[0]
+ # Look up the image by the <filename> stem first, then by the XML file stem
+ f_name = match(root_dir, prefile_name, prexlm_name)
+ if f_name is not None:
+ file_name.text = f_name
+ else:
+ prefile_name_set = set({prefile_name, prexlm_name})
+ prefile_name_set = ",".join(prefile_name_set)
+ suffix_set = ",".join(extension_list)
+ images_path = os.path.join(root_dir, "images")
+ info(
+ f"{images_path}/{{{prefile_name_set}}}.{{{suffix_set}}} does not exist and will be skipped."
+ )
+ continue
+ img_info = voc_get_image_info(ann_root, img_indexer)
+ output_json_dict["images"].append(img_info)
+ for obj in ann_root.findall("object"):
+ if obj.find("bndbox") is None: # Skip objects without a bndbox
+ continue
+ ann = voc_get_coco_annotation(obj=obj, label_indexer=label_indexer)
+ ann.update({"image_id": img_info["id"], "id": bnd_id})
+ output_json_dict["annotations"].append(ann)
+ bnd_id = bnd_id + 1
+ output_json_dict["categories"] = label_indexer.get_list(key_name="name")
+ output_file = os.path.join(output, output_file)
+ write_json_file(output_json_dict, output_file)
+ info(f"The converted annotations have been saved to {output_file}.")
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+import random
+import json
+from tqdm import tqdm
+from .....utils.file_interface import custom_open, write_json_file
+from .....utils.logging import info
+def split_dataset(root_dir, train_rate, val_rate):
+ """split dataset"""
+ assert (
+ train_rate + val_rate == 100
+ ), f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
+ assert (
+ train_rate > 0 and val_rate > 0
+ ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
+ all_image_info_list = []
+ all_category_dict = {}
+ max_image_id = 0
+ for fn in ["instance_train.json", "instance_val.json"]:
+ anno_path = os.path.join(root_dir, "annotations", fn)
+ if not os.path.exists(anno_path):
+ info(f"The annotation file {anno_path} does not exist and has been ignored!")
+ continue
+ image_info_list, category_list, max_image_id = json2list(
+ anno_path, max_image_id
+ )
+ all_image_info_list.extend(image_info_list)
+ for category in category_list:
+ if category["id"] not in all_category_dict:
+ all_category_dict[category["id"]] = category
+ total_num = len(all_image_info_list)
+ random.shuffle(all_image_info_list)
+ all_category_list = [all_category_dict[k] for k in all_category_dict]
+ start = 0
+ for fn, rate in [
+ ("instance_train.json", train_rate),
+ ("instance_val.json", val_rate),
+ ]:
+ end = start + round(total_num * rate / 100)
+ save_path = os.path.join(root_dir, "annotations", fn)
+ if os.path.exists(save_path):
+ bak_path = save_path + ".bak"
+ shutil.move(save_path, bak_path)
+ info(f"The original annotation file {fn} has been backed up to {bak_path}.")
+ assemble_write(all_image_info_list[start:end], all_category_list, save_path)
+ start = end
+ return root_dir
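The percentage-based split used by `split_dataset` can be sketched on a plain list. This is a simplified illustration, not the PaddleX implementation: the function name `split_by_percent` and the `seed` parameter are assumptions added so the example is reproducible.

```python
import random


def split_by_percent(items, train_rate, val_rate, seed=None):
    """Shuffle items and split them into train/val lists by integer percentages."""
    assert train_rate + val_rate == 100, "rates must sum to 100"
    assert train_rate > 0 and val_rate > 0, "rates must be positive"
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    # Number of training items; the validation split takes whatever remains.
    cut = round(len(shuffled) * train_rate / 100)
    return shuffled[:cut], shuffled[cut:]


train, val = split_by_percent(range(10), 80, 20, seed=0)
print(len(train), len(val))
```

Giving the remainder to the validation list is a choice made in this sketch so that every item lands in exactly one split even when a per-split count rounds down (Python rounds `2.5` to `2`).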
+def json2list(json_path, base_image_num):
+ """load json as list"""
+ assert os.path.exists(json_path), json_path
+ with custom_open(json_path, "r") as f:
+ data = json.load(f)
+ image_info_dict = {}
+ max_image_id = 0
+ for image_info in data["images"]:
+ # Offset the id so that image_id is unique across all annotation files
+ global_image_id = image_info["id"] + base_image_num
+ max_image_id = max(global_image_id, max_image_id)
+ image_info["id"] = global_image_id
+ image_info_dict[global_image_id] = {"img": image_info, "anno": []}
+ image_info_dict = {
+ image_info["id"]: {"img": image_info, "anno": []}
+ for image_info in data["images"]
+ }
+ info(f"Start loading annotation file {json_path}...")
+ for anno in tqdm(data["annotations"]):
+ global_image_id = anno["image_id"] + base_image_num
+ anno["image_id"] = global_image_id
+ image_info_dict[global_image_id]["anno"].append(anno)
+ image_info_list = [
+ (image_info_dict[image_info]["img"], image_info_dict[image_info]["anno"])
+ for image_info in image_info_dict
+ ]
+ return image_info_list, data["categories"], max_image_id
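The id offset applied in `json2list` is what keeps image ids unique when two COCO files are merged. A minimal sketch of that idea follows; the helper `offset_image_ids` and the toy dictionaries are hypothetical, and the sketch copies the records instead of modifying them in place.

```python
def offset_image_ids(coco, base):
    """Shift every image id (and each annotation's image_id) by `base`."""
    images = [dict(img, id=img["id"] + base) for img in coco["images"]]
    annos = [dict(a, image_id=a["image_id"] + base) for a in coco["annotations"]]
    max_id = max((img["id"] for img in images), default=base)
    return images, annos, max_id


first = {"images": [{"id": 1}, {"id": 2}], "annotations": [{"id": 1, "image_id": 2}]}
second = {"images": [{"id": 1}], "annotations": [{"id": 1, "image_id": 1}]}

# The largest id seen so far becomes the offset for the next file.
images_a, annos_a, max_id = offset_image_ids(first, 0)
images_b, annos_b, max_id = offset_image_ids(second, max_id)
merged_ids = [img["id"] for img in images_a + images_b]
print(merged_ids)
```

Annotation ids are not offset here; in the source they are renumbered from 1 when each split is written by `assemble_write`.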
+def assemble_write(image_info_list, category_list, save_path):
+ """assemble coco format and save to file"""
+ coco_data = {"categories": category_list}
+ image_list = [i[0] for i in image_info_list]
+ all_anno_list = []
+ for i in image_info_list:
+ all_anno_list.extend(i[1])
+ anno_list = []
+ for i, anno in enumerate(all_anno_list):
+ anno["id"] = i + 1
+ anno_list.append(anno)
+ coco_data["images"] = image_list
+ coco_data["annotations"] = anno_list
+ write_json_file(coco_data, save_path)
+ info(f"The split annotations have been saved to {save_path}.")
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
+"""Author: PaddlePaddle Authors"""
+import os
+import numpy as np
+import json
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+from pycocotools.coco import COCO
+from ......utils.fonts import PINGFANG_FONT_FILE_PATH
+from ......utils import logging
+def colormap(rgb=False):
+ """
+ Get colormap
+ The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/\
+ """
+ color_list = np.array(
+ [
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xCC,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x66,
+ 0x00,
+ 0x66,
+ 0xFF,
+ 0xCC,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x4D,
+ 0x00,
+ 0x80,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xB2,
+ 0x00,
+ 0x1A,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0xE5,
+ 0xFF,
+ 0x99,
+ 0x00,
+ 0x33,
+ 0xFF,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x33,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x99,
+ 0xFF,
+ 0xE5,
+ 0x00,
+ 0x00,
+ 0xFF,
+ 0x1A,
+ 0x00,
+ 0xB2,
+ 0xFF,
+ 0x80,
+ 0x00,
+ 0xFF,
+ 0xFF,
+ 0x00,
+ 0x4D,
+ ]
+ ).astype(np.float32)
+ color_list = color_list.reshape((-1, 3))
+ if not rgb:
+ color_list = color_list[:, ::-1]
+ return color_list.astype("int32")
+def font_colormap(color_index):
+ """
+ Get font color according to the index of colormap
+ """
+ dark = np.array([0x14, 0x0E, 0x35])
+ light = np.array([0xFF, 0xFF, 0xFF])
+ light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+ if color_index in light_indexs:
+ return light.astype("int32")
+ else:
+ return dark.astype("int32")
+def draw_bbox(image, coco_info: COCO, img_id):
+ """
+ Draw bbox on image
+ """
+ try:
+ image_info = coco_info.loadImgs(img_id)[0]
+ font_size = int(0.024 * int(image_info["width"])) + 2
+ except Exception:
+ font_size = 12
+ font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
+ image = image.convert("RGB")
+ draw = ImageDraw.Draw(image)
+ image_size = image.size
+ width = int(max(image_size) * 0.005)
+ catid2color = {}
+ catid2fontcolor = {}
+ catid_num_dict = {}
+ color_list = colormap(rgb=True)
+ annotations = coco_info.loadAnns(coco_info.getAnnIds(imgIds=img_id))
+ for ann in annotations:
+ catid = ann["category_id"]
+ catid_num_dict[catid] = catid_num_dict.get(catid, 0) + 1
+ for i, (catid, _) in enumerate(
+ sorted(catid_num_dict.items(), key=lambda x: x[1], reverse=True)
+ ):
+ if catid not in catid2color:
+ color_index = i % len(color_list)
+ catid2color[catid] = color_list[color_index]
+ catid2fontcolor[catid] = font_colormap(color_index)
+ for ann in annotations:
+ catid, bbox = ann["category_id"], ann["bbox"]
+ color = tuple(catid2color[catid])
+ font_color = tuple(catid2fontcolor[catid])
+ if len(bbox) == 4:
+ # draw bbox
+ xmin, ymin, w, h = bbox
+ xmax = xmin + w
+ ymax = ymin + h
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
+ width=width,
+ fill=color,
+ )
+ elif len(bbox) == 8:
+ x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+ draw.line(
+ [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+ width=width,
+ fill=color,
+ )
+ xmin = min(x1, x2, x3, x4)
+ ymin = min(y1, y2, y3, y4)
+ else:
+ logging.info("Error: The shape of bbox must be [M, 4] or [M, 8]!")
+ # Skip this annotation: there is no valid box to anchor the label to
+ continue
+ # draw label
+ label = coco_info.loadCats(catid)[0]["name"]
+ text = "{}".format(label)
+ # ImageDraw.textsize() was removed in Pillow 10.0.0, so use it only on older versions
+ if tuple(map(int, PIL.__version__.split("."))) < (10, 0, 0):
+ tw, th = draw.textsize(text, font=font)
+ else:
+ left, top, right, bottom = draw.textbbox((0, 0), text, font)
+ tw, th = right - left, bottom - top
+ if ymin < th:
+ draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
+ draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
+ else:
+ draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
+ draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
+ return image
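The color assignment inside `draw_bbox` ranks categories by how often they appear in the image and then walks the palette cyclically. A small sketch of that step, assuming a shortened palette (the first four RGB triples of `colormap`) and a hypothetical helper name `assign_colors`:

```python
from collections import Counter

# Illustrative subset of the palette: the first four RGB triples from colormap().
PALETTE = [(255, 0, 0), (204, 255, 0), (0, 255, 102), (0, 102, 255)]


def assign_colors(category_ids, palette=PALETTE):
    """Map each category id to a palette color, most frequent category first."""
    counts = Counter(category_ids)
    ranked = sorted(counts.items(), key=lambda kv: kv[1], reverse=True)
    return {cat: palette[i % len(palette)] for i, (cat, _) in enumerate(ranked)}


colors = assign_colors([3, 1, 3, 3, 1, 7])
print(colors)
```

Ranking by frequency means the most common category in an image always receives the first palette color, and the modulo keeps the lookup valid when there are more categories than colors.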
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..base import BaseEvaluator
+from .model_list import MODELS
+class MOTEvaluator(BaseEvaluator):
+ """Multi-Object Tracking Model Evaluator"""
+ entities = MODELS
+ def update_config(self):
+ """update evaluation config"""
+ if self.eval_config.log_interval:
+ self.pdx_config.update_log_interval(self.eval_config.log_interval)
+ self.pdx_config.update_dataset(self.global_config.dataset_dir, 'MOTImageFolder')
+ self.pdx_config.update_weights(self.eval_config.weight_path)
+ def get_eval_kwargs(self) -> dict:
+ """get key-value arguments of model evaluation function
+ Returns:
+ dict: the arguments of evaluation function.
+ """
+ return {
+ "weight_path": self.eval_config.weight_path,
+ "device": self.get_device(using_device_number=1),
+ }
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ..base import BaseExportor
+from .model_list import MODELS
+class MOTExportor(BaseExportor):
+ """Multi-Object Tracking Model Exportor"""
+ entities = MODELS
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+MODELS = [
+ "ByteTrack_PP-YOLOE_L",
+ "DeepSORT_PP-YOLOE_ResNet",
+ "FairMOT-DLA-34",
+]
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+import lazy_paddle as paddle
+from ..base import BaseTrainer
+from ...utils.config import AttrDict
+from ...utils import logging
+from .model_list import MODELS
+class MOTTrainer(BaseTrainer):
+ """Multi-Object Tracking Model Trainer"""
+ entities = MODELS
+ def _update_dataset(self):
+ """update dataset settings"""
+ self.pdx_config.update_dataset(self.global_config.dataset_dir, 'MOTDataSet')
+ def update_config(self):
+ """update training config"""
+ if self.train_config.log_interval:
+ self.pdx_config.update_log_interval(self.train_config.log_interval)
+ if self.train_config.eval_interval:
+ self.pdx_config.update_eval_interval(self.train_config.eval_interval)
+ self._update_dataset()
+ if self.train_config.num_classes is not None:
+ self.pdx_config.update_num_class(self.train_config.num_classes)
+ if (
+ self.train_config.pretrain_weight_path
+ and self.train_config.pretrain_weight_path != ""
+ ):
+ self.pdx_config.update_pretrained_weights(
+ self.train_config.pretrain_weight_path
+ )
+ if self.train_config.batch_size is not None:
+ self.pdx_config.update_batch_size(self.train_config.batch_size)
+ if self.train_config.learning_rate is not None:
+ self.pdx_config.update_learning_rate(self.train_config.learning_rate)
+ if self.train_config.epochs_iters is not None:
+ self.pdx_config.update_epochs(self.train_config.epochs_iters)
+ epochs_iters = self.train_config.epochs_iters
+ else:
+ epochs_iters = self.pdx_config.get_epochs_iters()
+ if self.global_config.output is not None:
+ self.pdx_config.update_save_dir(self.global_config.output)
+ if "PicoDet" in self.global_config.model:
+ assigner_epochs = max(int(epochs_iters / 10), 1)
+ try:
+ self.pdx_config.update_static_assigner_epochs(assigner_epochs)
+ except Exception:
+ logging.info(
+ f"The model({self.global_config.model}) does not support update_static_assigner_epochs!"
+ )
+ def get_train_kwargs(self) -> dict:
+ """get key-value arguments of model training function
+ Returns:
+ dict: the arguments of training function.
+ """
+ train_args = {"device": self.get_device()}
+ if (
+ self.train_config.resume_path is not None
+ and self.train_config.resume_path != ""
+ ):
+ train_args["resume_path"] = self.train_config.resume_path
+ train_args["dy2st"] = self.train_config.get("dy2st", False)
+ return train_args
from .object_det import DetModel, DetRunner, register
from .instance_seg import InstanceSegModel, InstanceSegRunner, register
+from .mot import MOTModel, MOTRunner, register
\ No newline at end of file
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+save_dir: output
+print_flops: false
+print_params: false
+log_iter: 20
+snapshot_epoch: 2
+# Dataset
+metric: MOT
+num_classes: 1
+# Detection Dataset for training
+# TrainDataset: !COCODataSet
+# dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+# anno_path: annotations/train_half.json
+# image_dir: images/train
+# data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+TrainDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/train_half.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+EvalDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+TestDataset: !ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ data_root: images/half # half
+ keep_ori_im: True # set as True in DeepSORT and ByteTrack
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ keep_ori_im: True # set True if save visualization images or video
+# Reader
+worker_num: 4
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+# add MOTReader for MOT evaluation and inference, note batch_size should be 1 in MOT
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+# Model
+architecture: ByteTrack
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+ema_black_list: ['proj_conv.weight']
+custom_black_list: ['reduce_mean']
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/ppyoloe_crn_l_300e_coco.pdparams
+depth_mult: 1.0
+width_mult: 1.0
+ detector: YOLOv3 # PPYOLOe version
+ reid: None
+ tracker: JDETracker
+det_weights: https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+# https://bj.bcebos.com/v1/paddledet/models/mot/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+reid_weights: None
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ post_process: ~
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+ out_channels: [768, 384, 192]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: -1
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.1
+ nms_threshold: 0.4
+# BYTETracker
+ use_byte: True
+ match_thres: 0.9
+ conf_thres: 0.2
+ low_conf_thres: 0.1
+ min_box_area: 100
+ vertical_ratio: 1.6 # for pedestrian
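The tracker thresholds above can be read as two filters. The sketch below follows the general ByteTrack idea of keeping a high-confidence and a low-confidence group of detections, plus a geometric check on box size and aspect ratio. It is an assumption-laden illustration: the function names are hypothetical, the boundary handling at exactly `conf_thres` is a choice made here, and where PaddleDetection applies each filter is not shown in this diff.

```python
def split_by_score(dets, conf_thres=0.2, low_conf_thres=0.1):
    """Split (x, y, w, h, score) detections into high- and low-confidence groups."""
    high, low = [], []
    for det in dets:
        score = det[4]
        if score >= conf_thres:
            high.append(det)
        elif score > low_conf_thres:
            low.append(det)
        # Anything at or below low_conf_thres is discarded.
    return high, low


def keep_box(w, h, min_box_area=100, vertical_ratio=1.6):
    """Return True if a box is large enough and not too wide for a pedestrian."""
    if h <= 0 or w * h <= min_box_area:
        return False
    if vertical_ratio > 0 and w / h > vertical_ratio:
        return False
    return True


dets = [(0, 0, 10, 20, 0.9), (0, 0, 10, 20, 0.15), (0, 0, 10, 20, 0.05)]
high, low = split_by_score(dets)
print(len(high), len(low), keep_box(10, 20), keep_box(40, 10))
```

In ByteTrack the low-confidence group gets a second matching round against tracks left unmatched by the high-confidence group, which is what `use_byte: True` refers to.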
+# Optimizer
+epoch: 36
+ base_lr: 0.001
+ schedulers:
+ - name: CosineDecay
+ max_epochs: 43
+ - name: LinearWarmup
+ start_factor: 0.001
+ epochs: 1
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
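To see what the schedule above implies, the learning rate can be evaluated per epoch. This is a generic sketch of linear warmup followed by cosine decay using the values from this config. It works at epoch granularity and is not taken from PaddleDetection, whose schedulers operate per iteration; the function name `lr_at` is hypothetical.

```python
import math


def lr_at(epoch, base_lr=0.001, max_epochs=43, warmup_epochs=1, start_factor=0.001):
    """Learning rate at a (possibly fractional) epoch: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        alpha = epoch / warmup_epochs
        # Interpolate from base_lr * start_factor up to base_lr.
        return base_lr * (start_factor * (1 - alpha) + alpha)
    return base_lr * 0.5 * (math.cos(math.pi * epoch / max_epochs) + 1)


for e in (0, 0.5, 1, 18, 36):
    print(e, lr_at(e))
```

Because `max_epochs` (43) is larger than `epoch` (36), the cosine curve under this formula has not reached zero when training stops, so the last epochs still run with a small non-zero learning rate.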
+# Exporting the model
+ post_process: True # Whether post-processing is included in the network when exporting the model.
+ nms: True # Whether NMS is included in the network when exporting the model.
+ benchmark: False # Used to test model performance; if set to `True`, post-processing and NMS will not be exported.
+ fuse_conv_bn: False
diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml
new file mode 100644
index 0000000000..c710cd47e2
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/configs/DeepSORT_PP-YOLOE_ResNet.yaml
@@ -0,0 +1,208 @@
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+save_dir: output
+print_flops: false
+print_params: false
+use_ema: true
+log_iter: 20
+snapshot_epoch: 2
+# Dataset
+metric: MOT
+num_classes: 1
+# Detection Dataset for training
+TrainDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/train_half.json
+ image_dir: images/train
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd']
+EvalDataset: !COCODataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+ image_dir: images/train
+TestDataset: !ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ anno_path: annotations/val_half.json
+# MOTDataset for MOT evaluation and inference
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ data_root: images/half
+ keep_ori_im: True # set as True in DeepSORT and ByteTrack
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets/MOT17
+ keep_ori_im: True # set True if save visualization images or video
+# Reader
+worker_num: 4
+eval_height: &eval_height 640
+eval_width: &eval_width 640
+eval_size: &eval_size [*eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - RandomDistort: {}
+ - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
+ - RandomCrop: {}
+ - RandomFlip: {}
+ batch_transforms:
+ - BatchRandomResize: {target_size: [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768], random_size: True, random_interp: True, keep_ratio: False}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ - PadGT: {}
+ batch_size: 8
+ shuffle: true
+ drop_last: true
+ use_shared_memory: true
+ collate_batch: true
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 2
+ inputs_def:
+ image_shape: [3, *eval_height, *eval_width]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: *eval_size, keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ inputs_def:
+ image_shape: [3, 640, 640]
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [640, 640], keep_ratio: False, interp: 2}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+# Model
+ema_black_list: ['proj_conv.weight']
+norm_type: sync_bn
+use_ema: true
+ema_decay: 0.9998
+depth_mult: 1.0
+width_mult: 1.0
+det_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/ppyoloe_crn_l_36e_640x640_mot17half.pdparams
+reid_weights: https://paddledet.bj.bcebos.com/models/mot/deepsort/deepsort_resnet.pdparams
+architecture: DeepSORT
+pretrain_weights: None
+ backbone: CSPResNet
+ neck: CustomCSPPAN
+ yolo_head: PPYOLOEHead
+ post_process: ~
+ layers: [3, 6, 6, 3]
+ channels: [64, 128, 256, 512, 1024]
+ return_idx: [1, 2, 3]
+ use_large_stem: True
+ out_channels: [768, 384, 192]
+ stage_num: 1
+ block_num: 3
+ act: 'swish'
+ spp: true
+ fpn_strides: [32, 16, 8]
+ grid_cell_scale: 5.0
+ grid_cell_offset: 0.5
+ static_assigner_epoch: -1
+ use_varifocal_loss: True
+ loss_weight: {class: 1.0, iou: 2.5, dfl: 0.5}
+ static_assigner:
+ name: ATSSAssigner
+ topk: 9
+ assigner:
+ name: TaskAlignedAssigner
+ topk: 13
+ alpha: 1.0
+ beta: 6.0
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.4
+ nms_threshold: 0.6
+ detector: YOLOv3 # PPYOLOe version
+ reid: ResNetEmbedding
+ tracker: DeepSORTTracker
+ model_name: "ResNet50"
+ input_size: [64, 192]
+ min_box_area: 0
+ vertical_ratio: -1
+ budget: 100
+ max_age: 70
+ n_init: 3
+ metric_type: cosine
+ matching_threshold: 0.2
+ max_iou_distance: 0.9
+ motion: KalmanFilter
+# Optimizer
+epoch: 36
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 43
+ - !LinearWarmup
+ start_factor: 0.001
+ epochs: 1
+ optimizer:
+ momentum: 0.9
+ type: Momentum
+ regularizer:
+ factor: 0.0005
+ type: L2
+# Exporting the model
+ post_process: True # Whether post-processing is included in the network when exporting the model.
+ nms: True # Whether NMS is included in the network when exporting the model.
+ benchmark: False # Used to test model performance; if set to `True`, post-processing and NMS will not be exported.
+ fuse_conv_bn: False
\ No newline at end of file
diff --git a/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml b/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml
new file mode 100644
index 0000000000..f8940d07cc
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/configs/FairMOT-DLA-34.yaml
@@ -0,0 +1,171 @@
+# Runtime
+use_gpu: true
+use_xpu: false
+use_mlu: false
+use_npu: false
+log_iter: 20
+save_dir: output
+snapshot_epoch: 1
+print_flops: false
+print_params: false
+use_ema: true
+# Dataset
+metric: MOT
+num_classes: 1
+# for MOT training
+ name: MOTDataSet
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ image_lists: ['mot17.train'] #
+ data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
+ name: MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ data_root: MOT17/images/train
+ keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
+ data_fields: ['image']
+ name: ImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ anno_path: MOT17/annotations/val_half.json
+# for MOT evaluation
+# If you want to change the MOT evaluation dataset, please modify 'data_root'
+EvalMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ data_root: MOT17/images/train
+ keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
+ data_fields: ['image']
+# for MOT video inference
+TestMOTDataset: !MOTImageFolder
+ dataset_dir: /mnt/yys/dataset/mot_datasets
+ keep_ori_im: True # set True if save visualization images or video
+# Reader
+worker_num: 4
+ inputs_def:
+ image_shape: [3, 608, 1088]
+ sample_transforms:
+ - Decode: {}
+ - RGBReverse: {}
+ - AugmentHSV: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - MOTRandomAffine: {reject_outside: False}
+ - RandomFlip: {}
+ - BboxXYXY2XYWH: {}
+ - NormalizeBox: {}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1]}
+ - RGBReverse: {}
+ - Permute: {}
+ batch_transforms:
+ - Gt2FairMOTTarget: {}
+ batch_size: 6
+ shuffle: True
+ drop_last: True
+ use_shared_memory: True
+# EvalReader:
+# sample_transforms:
+# - Decode: {}
+# - LetterBoxResize: {target_size: [608, 1088]}
+# - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+# - Permute: {}
+# batch_size: 1
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+ inputs_def:
+ image_shape: [3, 608, 1088]
+ sample_transforms:
+ - Decode: {}
+ - LetterBoxResize: {target_size: [608, 1088]}
+ - NormalizeImage: {mean: [0, 0, 0], std: [1, 1, 1], is_scale: True}
+ - Permute: {}
+ batch_size: 1
+# Model
+architecture: FairMOT
+pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/fairmot_dla34_crowdhuman_pretrained.pdparams
+for_mot: True
+ detector: CenterNet
+ reid: FairMOTEmbeddingHead
+ loss: FairMOTLoss
+ tracker: JDETracker
+ backbone: DLA
+ neck: CenterNetDLAFPN
+ head: CenterNetHead
+ post_process: CenterNetPostProcess
+ down_ratio: 4
+ last_level: 5
+ out_channel: 0
+ dcn_v2: True
+ with_sge: False
+ head_planes: 256
+ prior_bias: -2.19
+ regress_ltrb: True
+ size_loss: 'L1'
+ loss_weight: {'heatmap': 1.0, 'size': 0.1, 'offset': 1.0, 'iou': 0.0}
+ add_iou: False
+ ch_head: 256
+ ch_emb: 128
+ max_per_img: 500
+ down_ratio: 4
+ regress_ltrb: True
+ conf_thres: 0.4
+ tracked_thresh: 0.4
+ metric_type: cosine
+ min_box_area: 200
+ vertical_ratio: 1.6 # for pedestrian
+# Optimizer
+epoch: 30
+ base_lr: 0.0001
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [20,]
+ use_warmup: False
+ optimizer:
+ type: Adam
+ regularizer: NULL
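The piecewise schedule configured above multiplies the learning rate by `gamma` at each milestone. A minimal sketch with the values from this config, assuming epoch-level granularity and a hypothetical function name `piecewise_lr`:

```python
def piecewise_lr(epoch, base_lr=0.0001, gamma=0.1, milestones=(20,)):
    """Multiply base_lr by gamma once for every milestone already reached."""
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed


for e in (0, 19, 20, 29):
    print(e, piecewise_lr(e))
```

With a single milestone at epoch 20 and 30 training epochs, the last ten epochs run at one tenth of the base learning rate.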
+# Exporting the model
+ post_process: True # Whether post-processing is included in the network when exporting the model.
+ nms: True # Whether NMS is included in the network when exporting the model.
+ benchmark: False # Used to test model performance; if set to `True`, post-processing and NMS will not be exported.
+ fuse_conv_bn: False
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py b/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py
new file mode 100644
index 0000000000..325812d306
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/__init__.py
@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .model import MOTModel
+from .runner import MOTRunner
+from . import register
+from .official_categories import official_categories
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/config.py b/paddlex/repo_apis/PaddleDetection_api/mot/config.py
new file mode 100644
index 0000000000..622afb5e83
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/config.py
@@ -0,0 +1,505 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ...base import BaseConfig
+from ....utils.misc import abspath
+from ....utils import logging
+from ..config_helper import PPDetConfigMixin
+class MOTConfig(BaseConfig, PPDetConfigMixin):
+ """MOTConfig"""
+ def load(self, config_path: str):
+ """load the config from config file
+ Args:
+ config_path (str): the config file path.
+ """
+ dict_ = self.load_config_literally(config_path)
+ self.reset_from_dict(dict_)
+ def dump(self, config_path: str):
+ """dump the config
+ Args:
+ config_path (str): the path to save dumped config.
+ """
+ self.dump_literal_config(config_path, self._dict)
+ def update(self, dict_like_obj: list):
+ """update self from dict
+ Args:
+ dict_like_obj (list): the list of pairs that contain key and value.
+ """
+ self.update_from_dict(dict_like_obj, self._dict)
+ def update_dataset(
+ self,
+ dataset_path: str,
+ dataset_type: str = None,
+ *,
+ data_fields: list[str] = None,
+ image_dir: str = "images/train",
+ train_anno_path: str = "annotations/train_half.json", # annotations/instance_train.json
+ val_anno_path: str = "annotations/val_half.json",
+ test_anno_path: str = "annotations/val_half.json",
+ ):
+ """update dataset settings
+ Args:
+ dataset_path (str): the root path of the dataset.
+ dataset_type (str, optional): the dataset type. Defaults to None.
+ data_fields (list[str], optional): the data fields in dataset. Defaults to None.
+ image_dir (str, optional): the image directory, relative to `dataset_path`. Defaults to "images/train".
+ train_anno_path (str, optional): the train annotation file, relative to `dataset_path`.
+ Defaults to "annotations/train_half.json".
+ val_anno_path (str, optional): the validation annotation file, relative to `dataset_path`.
+ Defaults to "annotations/val_half.json".
+ test_anno_path (str, optional): the test annotation file, relative to `dataset_path`.
+ Defaults to "annotations/val_half.json".
+ Raises:
+ ValueError: the `dataset_type` error.
+ """
+ dataset_path = abspath(dataset_path)
+ ds_cfg = {}
+ dataset_names = ['TrainDataset', 'EvalDataset', 'TestDataset', 'EvalMOTDataset', 'TestMOTDataset']
+ for data_name in dataset_names:
+ ds_cfg[data_name]={}
+ ds_cfg[data_name]['dataset_dir'] = dataset_path
+ self.update(ds_cfg)
+ def _make_dataset_config(
+ self,
+ dataset_root_path: str,
+ data_fields: list[str,] = None,
+ image_dir: str = "images",
+ train_anno_path: str = "annotations/instance_train.json",
+ val_anno_path: str = "annotations/instance_val.json",
+ test_anno_path: str = "annotations/instance_val.json",
+ ) -> dict:
+ """construct the dataset config that meets the format requirements
+ Args:
+ dataset_root_path (str): the root directory of dataset.
+ data_fields (list[str,], optional): the data field. Defaults to None.
+ image_dir (str, optional): _description_. Defaults to "images".
+ train_anno_path (str, optional): _description_. Defaults to "annotations/instance_train.json".
+ val_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
+ test_anno_path (str, optional): _description_. Defaults to "annotations/instance_val.json".
+ Returns:
+ dict: the dataset config.
+ """
+ data_fields = (
+ ["image", "gt_bbox", "gt_class", "is_crowd"]
+ if data_fields is None
+ else data_fields
+ )
+ return {
+ "TrainDataset": {
+ "name": "COCODetDataset",
+ "image_dir": image_dir,
+ "anno_path": train_anno_path,
+ "dataset_dir": dataset_root_path,
+ "data_fields": data_fields,
+ },
+ "EvalDataset": {
+ "name": "COCODetDataset",
+ "image_dir": image_dir,
+ "anno_path": val_anno_path,
+ "dataset_dir": dataset_root_path,
+ },
+ "TestDataset": {
+ "name": "ImageFolder",
+ "anno_path": test_anno_path,
+ "dataset_dir": dataset_root_path,
+ },
+ }
+ def update_ema(
+ self,
+ use_ema: bool,
+ ema_decay: float = 0.9999,
+ ema_decay_type: str = "exponential",
+ ema_filter_no_grad: bool = True,
+ ):
+ """update EMA setting
+ Args:
+ use_ema (bool): whether or not to use EMA
+ ema_decay (float, optional): value of EMA decay. Defaults to 0.9999.
+ ema_decay_type (str, optional): type of EMA decay. Defaults to "exponential".
+ ema_filter_no_grad (bool, optional): whether or not to filter out the parameters
+ that have been set to stop gradient and are not batch norm parameters. Defaults to True.
+ """
+ self.update(
+ {
+ "use_ema": use_ema,
+ "ema_decay": ema_decay,
+ "ema_decay_type": ema_decay_type,
+ "ema_filter_no_grad": ema_filter_no_grad,
+ }
+ )
+ def update_learning_rate(self, learning_rate: float):
+ """update learning rate
+ Args:
+ learning_rate (float): the learning rate value to set.
+ """
+ self.LearningRate["base_lr"] = learning_rate
+ def update_warmup_steps(self, warmup_steps: int):
+ """update warmup steps
+ Args:
+ warmup_steps (int): the warmup steps value to set.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "LinearWarmup":
+ sch["steps"] = warmup_steps
+ sch["epochs_first"] = False
+ def update_warmup_enable(self, use_warmup: bool):
+ """whether or not to enable learning rate warmup
+ Args:
+ use_warmup (bool): `True` enables learning rate warmup and `False` disables it.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ if "use_warmup" in sch:
+ sch["use_warmup"] = use_warmup
+ def update_cossch_epoch(self, max_epochs: int):
+ """update max epochs of cosine learning rate scheduler
+ Args:
+ max_epochs (int): the max epochs value.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "CosineDecay":
+ sch["max_epochs"] = max_epochs
+ def update_milestone(self, milestones: list[int]):
+ """update milestones of `PiecewiseDecay` learning scheduler
+ Args:
+ milestones (list[int]): the list of milestone values of `PiecewiseDecay` learning scheduler.
+ """
+ schedulers = self.LearningRate["schedulers"]
+ for sch in schedulers:
+ key = "name" if "name" in sch else "_type_"
+ if sch[key] == "PiecewiseDecay":
+ sch["milestones"] = milestones
+ def update_batch_size(self, batch_size: int, mode: str = "train"):
+ """update batch size setting
+ Args:
+ batch_size (int): the batch size number to set.
+ mode (str, optional): the mode that to be set batch size, must be one of 'train', 'eval', 'test'.
+ Defaults to 'train'.
+ Raises:
+ ValueError: mode error.
+ """
+ assert mode in (
+ "train",
+ "eval",
+ "test",
+ ), "mode ({}) should be train, eval or test".format(mode)
+ if mode == "train":
+ self.TrainReader["batch_size"] = batch_size
+ elif mode == "eval":
+ self.EvalReader["batch_size"] = batch_size
+ else:
+ self.TestReader["batch_size"] = batch_size
+ def update_epochs(self, epochs: int):
+ """update epochs setting
+ Args:
+ epochs (int): the epochs number value to set
+ """
+ self.update({"epoch": epochs})
+ def update_device(self, device_type: str):
+ """update device setting
+ Args:
+ device_type (str): the running device type to set
+ """
+ if device_type.lower() == "gpu":
+ self["use_gpu"] = True
+ elif device_type.lower() == "xpu":
+ self["use_xpu"] = True
+ self["use_gpu"] = False
+ elif device_type.lower() == "npu":
+ self["use_npu"] = True
+ self["use_gpu"] = False
+ elif device_type.lower() == "mlu":
+ self["use_mlu"] = True
+ self["use_gpu"] = False
+ else:
+ assert device_type.lower() == "cpu"
+ self["use_gpu"] = False
+ def update_save_dir(self, save_dir: str):
+ """update directory to save outputs
+ Args:
+ save_dir (str): the directory to save outputs.
+ """
+ self["save_dir"] = abspath(save_dir)
+ def update_log_interval(self, log_interval: int):
+ """update log interval(steps)
+ Args:
+ log_interval (int): the log interval value to set.
+ """
+ self.update({"log_iter": log_interval})
+ def update_eval_interval(self, eval_interval: int):
+ """update eval interval(epochs)
+ Args:
+ eval_interval (int): the eval interval value to set.
+ """
+ self.update({"snapshot_epoch": eval_interval})
+ def update_save_interval(self, save_interval: int):
+ """update save interval(epochs)
+ Args:
+ save_interval (int): the save interval value to set.
+ """
+ self.update({"snapshot_epoch": save_interval})
+ def update_log_ranks(self, device):
+ """update log ranks
+ Args:
+ device (str): the running device to set
+ """
+ log_ranks = device.split(":")[1]
+ self.update({"log_ranks": log_ranks})
+ def update_print_mem_info(self, print_mem_info: bool):
+ """setting print memory info"""
+ assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
+ self.update({"print_mem_info": f"{print_mem_info}"})
+ def update_shared_memory(self, shared_memeory: bool):
+ """update shared memory setting of train and eval dataloader
+ Args:
+ shared_memeory (bool): whether or not to use shared memory
+ """
+ assert isinstance(shared_memeory, bool), "shared_memeory should be a bool"
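+ # NOTE: assumes PaddleDetection's reader-level `use_shared_memory` option.
+ self.update({"TrainReader": {"use_shared_memory": shared_memeory}})
+ self.update({"EvalReader": {"use_shared_memory": shared_memeory}})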
+ def update_shuffle(self, shuffle: bool):
+ """update shuffle setting of train and eval dataloader
+ Args:
+ shuffle (bool): whether or not to shuffle the data
+ """
+ assert isinstance(shuffle, bool), "shuffle should be a bool"
+ self.update({"TrainReader": {"shuffle": shuffle}})
+ self.update({"EvalReader": {"shuffle": shuffle}})
+ def update_weights(self, weight_path: str):
+ """update model weight
+ Args:
+ weight_path (str): the path to weight file of model.
+ """
+ self["weights"] = weight_path
+ def update_pretrained_weights(self, pretrain_weights: str):
+ """update pretrained weight path
+ Args:
+ pretrain_weights (str): the local path or url of pretrained weight file to set.
+ """
+ if not pretrain_weights.startswith(
+ "http://"
+ ) and not pretrain_weights.startswith("https://"):
+ pretrain_weights = abspath(pretrain_weights)
+ self["pretrain_weights"] = pretrain_weights
+ def update_num_class(self, num_classes: int):
+ """update classes number
+ Args:
+ num_classes (int): the classes number value to set.
+ """
+ self["num_classes"] = num_classes
+ if 'CenterNet' in self.model_name:
+ for i in range(len(self['TrainReader']['sample_transforms'])):
+ if 'Gt2CenterNetTarget' in self['TrainReader']['sample_transforms'][i].keys():
+ self['TrainReader']['sample_transforms'][i]['Gt2CenterNetTarget']['num_classes'] = num_classes
+ def update_random_size(self, randomsize: list[list[int, int]]):
+ """update `target_size` of `BatchRandomResize` op in TestReader
+ Args:
+ randomsize (list[list[int, int]]): the list of different size scales.
+ """
+ self.TestReader["batch_transforms"]["BatchRandomResize"][
+ "target_size"
+ ] = randomsize
+ def update_num_workers(self, num_workers: int):
+ """update workers number of train and eval dataloader
+ Args:
+ num_workers (int): the value of train and eval dataloader workers number to set.
+ """
+ self["worker_num"] = num_workers
+ def _recursively_set(self, config: dict, update_dict: dict):
+ """recursively set config
+ Args:
+ config (dict): the original config.
+ update_dict (dict): the parameters to update and their values
+ Example:
+ self._recursively_set(self.HybridEncoder, {'encoder_layer': {'dim_feedforward': 2048}})
+ """
+ assert isinstance(update_dict, dict)
+ for key in update_dict:
+ if key not in config:
+ logging.info(f"A new field of config to set found: {repr(key)}.")
+ config[key] = update_dict[key]
+ elif not isinstance(update_dict[key], dict):
+ config[key] = update_dict[key]
+ else:
+ self._recursively_set(config[key], update_dict[key])
+ def update_static_assigner_epochs(self, static_assigner_epochs: int):
+ """update static assigner epochs value
+ Args:
+ static_assigner_epochs (int): the value of static assigner epochs
+ """
+ assert "PicoHeadV2" in self
+ self.PicoHeadV2["static_assigner_epoch"] = static_assigner_epochs
+ def update_HybridEncoder(self, update_dict: dict):
+ """update the HybridEncoder neck setting
+ Args:
+ update_dict (dict): the HybridEncoder setting.
+ """
+ assert "HybridEncoder" in self
+ self._recursively_set(self.HybridEncoder, update_dict)
+ def get_epochs_iters(self) -> int:
+ """get epochs
+ Returns:
+ int: the epochs value, i.e., `epoch` in config.
+ """
+ return self.epoch
+ def get_log_interval(self) -> int:
+ """get log interval(steps)
+ Returns:
+ int: the log interval value, i.e., `log_iter` in config.
+ """
+ return self.log_iter
+ def get_eval_interval(self) -> int:
+ """get eval interval(epochs)
+ Returns:
+ int: the eval interval value, i.e., `snapshot_epoch` in config.
+ """
+ return self.snapshot_epoch
+ def get_save_interval(self) -> int:
+ """get save interval(epochs)
+ Returns:
+ int: the save interval value, i.e., `snapshot_epoch` in config.
+ """
+ return self.snapshot_epoch
+ def get_learning_rate(self) -> float:
+ """get learning rate
+ Returns:
+ float: the learning rate value, i.e., `LearningRate.base_lr` in config.
+ """
+ return self.LearningRate["base_lr"]
+ def get_batch_size(self, mode="train") -> int:
+ """get batch size
+ Args:
+ mode (str, optional): the mode that to be get batch size value, must be one of 'train', 'eval', 'test'.
+ Defaults to 'train'.
+ Returns:
+ int: the batch size value of `mode`, i.e., `batch_size` of `TrainReader`, `EvalReader` or `TestReader` in config.
+ """
+ if mode == "train":
+ return self.TrainReader["batch_size"]
+ elif mode == "eval":
+ return self.EvalReader["batch_size"]
+ elif mode == "test":
+ return self.TestReader["batch_size"]
+ else:
+ raise ValueError(f"Unknown mode: {repr(mode)}")
+ def get_qat_epochs_iters(self) -> int:
+ """get qat epochs
+ Returns:
+ int: the epochs value.
+ """
+ return self.epoch // 2
+ def get_qat_learning_rate(self) -> float:
+ """get qat learning rate
+ Returns:
+ float: the learning rate value.
+ """
+ return self.LearningRate["base_lr"] / 2.0
+ def get_train_save_dir(self) -> str:
+ """get the directory to save output
+ Returns:
+ str: the directory to save output
+ """
+ return self.save_dir
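Note on `_recursively_set` in the config class above: it merges a nested update into an existing config, overwriting leaf values and adding unknown keys while leaving sibling keys untouched. The sketch below reproduces that behaviour on plain dictionaries so it can be run without PaddleX. The standalone function name `recursively_set` and the sample values (`1024`, `nhead`) are assumptions made for illustration only.

```python
def recursively_set(config: dict, update_dict: dict) -> None:
    """Merge `update_dict` into `config` in place.

    Leaf values and previously unknown keys are assigned directly;
    nested dicts that already exist in `config` are merged key by key.
    """
    assert isinstance(update_dict, dict)
    for key, value in update_dict.items():
        if key not in config or not isinstance(value, dict):
            config[key] = value
        else:
            recursively_set(config[key], value)


# Hypothetical config fragment; only `dim_feedforward` should change.
cfg = {"HybridEncoder": {"encoder_layer": {"dim_feedforward": 1024, "nhead": 8}}}
recursively_set(cfg["HybridEncoder"], {"encoder_layer": {"dim_feedforward": 2048}})
print(cfg)
```

A plain `dict.update` at the top level would replace the whole `encoder_layer` entry and lose `nhead`, which is the case the recursive merge avoids.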
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/model.py b/paddlex/repo_apis/PaddleDetection_api/mot/model.py
new file mode 100644
index 0000000000..70ee585c54
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/model.py
@@ -0,0 +1,433 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import json
+from ...base import BaseModel
+from ...base.utils.arg import CLIArgument
+from ...base.utils.subprocess import CompletedProcess
+from ....utils.device import parse_device
+from ....utils.misc import abspath
+from ....utils import logging
+from .config import MOTConfig
+from .official_categories import official_categories
+class MOTModel(BaseModel):
+ """Multi-Object Tracking Model"""
+ def train(
+ self,
+ batch_size: int = None,
+ learning_rate: float = None,
+ epochs_iters: int = None,
+ ips: str = None,
+ device: str = "gpu",
+ resume_path: str = None,
+ dy2st: bool = False,
+ amp: str = "OFF",
+ num_workers: int = None,
+ use_vdl: bool = True,
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """train self
+ Args:
+ batch_size (int, optional): the train batch size value. Defaults to None.
+ learning_rate (float, optional): the train learning rate value. Defaults to None.
+ epochs_iters (int, optional): the train epochs value. Defaults to None.
+ ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ resume_path (str, optional): the checkpoint file path to resume training. Train from scratch if it is set
+ to None. Defaults to None.
+ dy2st (bool, optional): Enable dynamic to static. Defaults to False.
+ amp (str, optional): the amp settings. Defaults to 'OFF'.
+ num_workers (int, optional): the workers number. Defaults to None.
+ use_vdl (bool, optional): enable VisualDL. Defaults to True.
+ save_dir (str, optional): the directory path to save train output. Defaults to None.
+ Returns:
+ CompletedProcess: the result of training subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+ if batch_size is not None:
+ config.update_batch_size(batch_size, "train")
+ if learning_rate is not None:
+ config.update_learning_rate(learning_rate)
+ if epochs_iters is not None:
+ config.update_epochs(epochs_iters)
+ config.update_cossch_epoch(epochs_iters)
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if resume_path is not None:
+ assert resume_path.endswith(
+ ".pdparams"
+ ), "resume_path should end with .pdparams"
+ resume_dir = resume_path[0:-9]
+ cli_args.append(CLIArgument("--resume", resume_dir))
+ if dy2st:
+ cli_args.append(CLIArgument("--to_static"))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+ if save_dir is None:
+ save_dir = abspath(config.get_train_save_dir())
+ else:
+ save_dir = abspath(save_dir)
+ config.update_save_dir(save_dir)
+ if use_vdl:
+ cli_args.append(CLIArgument("--use_vdl", use_vdl))
+ cli_args.append(CLIArgument("--vdl_log_dir", save_dir))
+ do_eval = kwargs.pop("do_eval", True)
+ enable_ce = kwargs.pop("enable_ce", None)
+ profile = kwargs.pop("profile", None)
+ if profile is not None:
+ cli_args.append(CLIArgument("--profiler_options", profile))
+ # Benchmarking mode settings
+ benchmark = kwargs.pop("benchmark", None)
+ if benchmark is not None:
+ envs = benchmark.get("env", None)
+ amp = benchmark.get("amp", None)
+ do_eval = benchmark.get("do_eval", False)
+ num_workers = benchmark.get("num_workers", None)
+ config.update_log_ranks(device)
+ config.update_shuffle(benchmark.get("shuffle", False))
+ config.update_shared_memory(benchmark.get("shared_memory", True))
+ config.update_print_mem_info(benchmark.get("print_mem_info", True))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+ if amp == "O1":
+ # TODO: ppdet only support ampO1
+ cli_args.append(CLIArgument("--amp"))
+ if envs is not None:
+ for env_name, env_value in envs.items():
+ os.environ[env_name] = str(env_value)
+ # set seed to 0 for benchmark mode by enable_ce
+ cli_args.append(CLIArgument("--enable_ce", True))
+ else:
+ if amp != "OFF" and amp is not None:
+ # TODO: consider amp is O1 or O2 in ppdet
+ cli_args.append(CLIArgument("--amp"))
+ if enable_ce is not None:
+ cli_args.append(CLIArgument("--enable_ce", enable_ce))
+ # PDX related settings
+ if device_type in ["npu", "xpu", "mlu"]:
+ uniform_output_enabled = False
+ else:
+ uniform_output_enabled = True
+ config.update({"uniform_output_enabled": uniform_output_enabled})
+ config.update({"pdx_model_name": self.name})
+ hpi_config_path = self.model_info.get("hpi_config_path", None)
+ if hpi_config_path:
+ hpi_config_path = hpi_config_path.as_posix()
+ config.update({"hpi_config_path": hpi_config_path})
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.train(
+ config_path, cli_args, device, ips, save_dir, do_eval=do_eval
+ )
+ def evaluate(
+ self,
+ weight_path: str,
+ batch_size: int = None,
+ ips: str = None,
+ device: str = "gpu",
+ amp: str = "OFF",
+ num_workers: int = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """evaluate self using specified weight
+ Args:
+ weight_path (str): the path of model weight file to be evaluated.
+ batch_size (int, optional): the batch size value in evaluating. Defaults to None.
+ ips (str, optional): the ip addresses of nodes when using distribution. Defaults to None.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ amp (str, optional): the AMP setting. Defaults to 'OFF'.
+ num_workers (int, optional): the workers number in evaluating. Defaults to None.
+ Returns:
+ CompletedProcess: the result of evaluating subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ if batch_size is not None:
+ config.update_batch_size(batch_size, "eval")
+ device_type, device_ids = parse_device(device)
+ if len(device_ids) > 1:
+ raise ValueError(
+ f"multi-{device_type} evaluation is not supported. Please use a single {device_type}."
+ )
+ config.update_device(device_type)
+ if amp != "OFF":
+ # TODO: consider amp is O1 or O2 in ppdet
+ cli_args.append(CLIArgument("--amp"))
+ if num_workers is not None:
+ config.update_num_workers(num_workers)
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ cp = self.runner.evaluate(config_path, cli_args, device, ips, config.dict['architecture'])
+ return cp
+ def predict(
+ self,
+ input_path: str,
+ weight_path: str,
+ device: str = "gpu",
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """predict using specified weight
+ Args:
+ weight_path (str): the path of model weight file used to predict.
+ input_path (str): the path of image file to be predicted.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ save_dir (str, optional): the directory path to save predict output. Defaults to None.
+ Returns:
+ CompletedProcess: the result of predicting subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+ input_path = abspath(input_path)
+ if os.path.isfile(input_path):
+ cli_args.append(CLIArgument("--infer_img", input_path))
+ else:
+ cli_args.append(CLIArgument("--infer_dir", input_path))
+ if "infer_list" in kwargs:
+ infer_list = abspath(kwargs.get("infer_list"))
+ cli_args.append(CLIArgument("--infer_list", infer_list))
+ if "visualize" in kwargs:
+ cli_args.append(CLIArgument("--visualize", kwargs["visualize"]))
+ if "save_results" in kwargs:
+ cli_args.append(CLIArgument("--save_results", kwargs["save_results"]))
+ if "save_threshold" in kwargs:
+ cli_args.append(CLIArgument("--save_threshold", kwargs["save_threshold"]))
+ if "rtn_im_file" in kwargs:
+ cli_args.append(CLIArgument("--rtn_im_file", kwargs["rtn_im_file"]))
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if save_dir is not None:
+ save_dir = abspath(save_dir)
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.predict(config_path, cli_args, device)
+ def export(self, weight_path: str, save_dir: str, **kwargs) -> CompletedProcess:
+ """export the dynamic model to static model
+ Args:
+ weight_path (str): the model weight file path that used to export.
+ save_dir (str): the directory path to save export output.
+ Returns:
+ CompletedProcess: the result of exporting subprocess execution.
+ """
+ config = self.config.copy()
+ cli_args = []
+ device = kwargs.pop("device", None)
+ if device:
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if not weight_path.startswith("http"):
+ weight_path = abspath(weight_path)
+ config.update_weights(weight_path)
+ save_dir = abspath(save_dir)
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+ input_shape = kwargs.pop("input_shape", None)
+ if input_shape is not None:
+ cli_args.append(
+ CLIArgument("-o", f"TestReader.inputs_def.image_shape={input_shape}")
+ )
+ use_trt = kwargs.pop("use_trt", None)
+ if use_trt is not None:
+ cli_args.append(CLIArgument("-o", f"trt={bool(use_trt)}"))
+ exclude_nms = kwargs.pop("exclude_nms", None)
+ if exclude_nms is not None:
+ cli_args.append(CLIArgument("-o", f"exclude_nms={bool(exclude_nms)}"))
+ # PDX related settings
+ config.update({"pdx_model_name": self.name})
+ hpi_config_path = self.model_info.get("hpi_config_path", None)
+ if hpi_config_path:
+ hpi_config_path = hpi_config_path.as_posix()
+ config.update({"hpi_config_path": hpi_config_path})
+ if self.name in official_categories.keys():
+ anno_val_file = abspath(
+ os.path.join(
+ config.TestDataset["dataset_dir"], config.TestDataset["anno_path"]
+ )
+ )
+ if anno_val_file is None or (not os.path.isfile(anno_val_file)):
+ categories = official_categories[self.name]
+ temp_anno = {"images": [], "annotations": [], "categories": categories}
+ with self._create_new_val_json_file() as anno_file:
+ # Write through a context manager so the file handle is closed promptly.
+ with open(anno_file, "w") as f:
+ json.dump(temp_anno, f)
+ config.update(
+ {"TestDataset": {"dataset_dir": "", "anno_path": anno_file}}
+ )
+ logging.warning(
+ f"{self.name} does not have validation annotations; using the default {anno_file} instead."
+ )
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.export(config_path, cli_args, None)
+ self._assert_empty_kwargs(kwargs)
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ return self.runner.export(config_path, cli_args, None)
+ def infer(
+ self,
+ model_dir: str,
+ input_path: str,
+ device: str = "gpu",
+ save_dir: str = None,
+ **kwargs,
+ ):
+ """predict image using inference model
+ Args:
+ model_dir (str): the directory path of the inference model files used to predict.
+ input_path (str): the path of the image to be predicted.
+ device (str, optional): the running device. Defaults to 'gpu'.
+ save_dir (str, optional): the directory path to save output. Defaults to None.
+ Returns:
+ CompletedProcess: the result of inferring subprocess execution.
+ """
+ model_dir = abspath(model_dir)
+ input_path = abspath(input_path)
+ if save_dir is not None:
+ save_dir = abspath(save_dir)
+ cli_args = []
+ cli_args.append(CLIArgument("--model_dir", model_dir))
+ cli_args.append(CLIArgument("--image_file", input_path))
+ if save_dir is not None:
+ cli_args.append(CLIArgument("--output_dir", save_dir))
+ device_type, _ = parse_device(device)
+ cli_args.append(CLIArgument("--device", device_type))
+ self._assert_empty_kwargs(kwargs)
+ return self.runner.infer(cli_args, device)
+ def compression(
+ self,
+ weight_path: str,
+ batch_size: int = None,
+ learning_rate: float = None,
+ epochs_iters: int = None,
+ device: str = None,
+ use_vdl: bool = True,
+ save_dir: str = None,
+ **kwargs,
+ ) -> CompletedProcess:
+ """compression model
+ Args:
+ weight_path (str): the path to weight file of model.
+ batch_size (int, optional): the batch size value of compression training. Defaults to None.
+ learning_rate (float, optional): the learning rate value of compression training. Defaults to None.
+ epochs_iters (int, optional): the epochs or iters of compression training. Defaults to None.
+ device (str, optional): the device to run compression training. Defaults to None.
+ use_vdl (bool, optional): whether or not to use VisualDL. Defaults to True.
+ save_dir (str, optional): the directory to save output. Defaults to None.
+ Returns:
+ CompletedProcess: the result of compression subprocess execution.
+ """
+ weight_path = abspath(weight_path)
+ if save_dir is None:
+ save_dir = self.config["save_dir"]
+ save_dir = abspath(save_dir)
+ config = self.config.copy()
+ cps_config = MOTConfig(
+ self.name, config_path=self.model_info["auto_compression_config_path"]
+ )
+ train_cli_args = []
+ export_cli_args = []
+ cps_config.update_pretrained_weights(weight_path)
+ if batch_size is not None:
+ cps_config.update_batch_size(batch_size, "train")
+ if learning_rate is not None:
+ cps_config.update_learning_rate(learning_rate)
+ if epochs_iters is not None:
+ cps_config.update_epochs(epochs_iters)
+ if device is not None:
+ device_type, _ = parse_device(device)
+ config.update_device(device_type)
+ if save_dir is None:
+ save_dir = abspath(config.get_train_save_dir())
+ else:
+ save_dir = abspath(save_dir)
+ cps_config.update_save_dir(save_dir)
+ if use_vdl:
+ train_cli_args.append(CLIArgument("--use_vdl", use_vdl))
+ train_cli_args.append(CLIArgument("--vdl_log_dir", save_dir))
+ export_cli_args.append(
+ CLIArgument("--output_dir", os.path.join(save_dir, "export"))
+ )
+ with self._create_new_config_file() as config_path:
+ config.dump(config_path)
+ # TODO: refactor me
+ cps_config_path = config_path[0:-4] + "_compression" + config_path[-4:]
+ cps_config.dump(cps_config_path)
+ train_cli_args.append(CLIArgument("--slim_config", cps_config_path))
+ export_cli_args.append(CLIArgument("--slim_config", cps_config_path))
+ self._assert_empty_kwargs(kwargs)
+ return self.runner.compression(
+ config_path, train_cli_args, export_cli_args, device, save_dir
+ )
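Note on `resume_path` handling in `train` above: the method asserts that the checkpoint path ends in `.pdparams` and then passes the path without that suffix to `--resume` by slicing off the last nine characters. The sketch below shows the same transformation with the suffix length derived from the suffix itself rather than hard-coded. The helper name `resume_prefix` is hypothetical, and raising `ValueError` instead of using `assert` is a design choice of this example, not something the patch does.

```python
def resume_prefix(resume_path: str, suffix: str = ".pdparams") -> str:
    """Strip the checkpoint suffix so the result can be passed to --resume.

    Hypothetical helper mirroring `resume_path[0:-9]` in the patch.
    """
    if not resume_path.endswith(suffix):
        raise ValueError(f"resume_path should end with {suffix}: {resume_path!r}")
    return resume_path[: -len(suffix)]


print(resume_prefix("output/best_model/model.pdparams"))
```

Deriving the slice length from `suffix` keeps the check and the slice consistent if the extension ever changes.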
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py b/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py
new file mode 100644
index 0000000000..7ae51cd0bc
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/official_categories.py
@@ -0,0 +1,14 @@
+official_categories = {
+'PP-YOLOE-L_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-YOLOE-L_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-ShiTuV2_det': [{"name": "mainbody", "id": 0}],
+'PicoDet_layout_1x': [{"name": "Text", "id": 0}, {"name": "Title", "id": 1}, {"name": "List", "id": 2}, {"name": "Table", "id": 3}, {"name": "Figure", "id": 4}],
+'PicoDet-L_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_17cls': [{"name": "paragraph_title", "id": 0}, {"name": "image", "id": 1}, {"name": "text", "id": 2}, {"name": "number", "id": 3}, {"name": "abstract", "id": 4}, {"name": "content", "id": 5}, {"name": "figure_title", "id": 6}, {"name": "formula", "id": 7}, {"name": "table", "id": 8}, {"name": "table_title", "id": 9}, {"name": "reference", "id": 10}, {"name": "doc_title", "id": 11}, {"name": "footnote", "id": 12}, {"name": "header", "id": 13}, {"name": "algorithm", "id": 14}, {"name": "footer", "id": 15}, {"name": "seal", "id": 16}],
+'PP-YOLOE_plus_SOD-S': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-largesize-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+}
\ No newline at end of file
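Review note: the table above maps a model name to a list of `{"name", "id"}` records. A minimal sketch of how a caller might turn one entry into an id-to-name lookup follows. The helper name `build_id_to_name` is hypothetical and not part of this patch, and only two entries of the table are copied for brevity.

```python
# Excerpt of the table added above; only two entries are copied for brevity.
official_categories = {
    "PP-YOLOE-L_human": [{"name": "pedestrian", "id": 0}],
    "PicoDet-L_layout_3cls": [
        {"name": "image", "id": 0},
        {"name": "table", "id": 1},
        {"name": "seal", "id": 2},
    ],
}


def build_id_to_name(model_name, table=official_categories):
    """Hypothetical helper: map class id -> class name for one model."""
    if model_name not in table:
        raise KeyError(f"no official categories listed for {model_name!r}")
    return {cat["id"]: cat["name"] for cat in table[model_name]}


if __name__ == "__main__":
    print(build_id_to_name("PicoDet-L_layout_3cls"))
```

Building the lookup once per model avoids scanning the list for every detection.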
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/register.py b/paddlex/repo_apis/PaddleDetection_api/mot/register.py
new file mode 100644
index 0000000000..9e8958c210
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/register.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+from pathlib import Path
+from ...base.register import register_model_info, register_suite_info
+from .model import MOTModel
+from .config import MOTConfig
+from .runner import MOTRunner
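+
+# NOTE (assumption): REPO_ROOT_PATH is used below but its definition does not
+# appear in this hunk. The environment variable name here is assumed, not
+# confirmed by the patch; verify it against the other PaddleDetection suites.
+REPO_ROOT_PATH = os.environ.get("PADDLE_PDX_PADDLEDETECTION_PATH")
+PDX_CONFIG_DIR = osp.abspath(osp.join(osp.dirname(__file__), "..", "configs"))
+HPI_CONFIG_DIR = Path(__file__).parent.parent.parent.parent / "utils" / "hpi_configs"
+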
+register_suite_info(
+    {
+        "suite_name": "MOT",
+        "model": MOTModel,
+        "runner": MOTRunner,
+        "config": MOTConfig,
+        "runner_root_path": REPO_ROOT_PATH,
+    }
+)
+
+################ Models Using Universal Config ################
+register_model_info(
+    {
+        "model_name": "ByteTrack_PP-YOLOE_L",
+        "suite": "MOT",
+        "config_path": osp.join(PDX_CONFIG_DIR, "ByteTrack_PP-YOLOE_L.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+        "hpi_config_path": HPI_CONFIG_DIR / "ByteTrack_PP-YOLOE_L.yaml",
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "DeepSORT_PP-YOLOE_ResNet",
+        "suite": "MOT",
+        "config_path": osp.join(PDX_CONFIG_DIR, "DeepSORT_PP-YOLOE_ResNet.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["COCODetDataset"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+        "hpi_config_path": HPI_CONFIG_DIR / "DeepSORT_PP-YOLOE_ResNet.yaml",
+    }
+)
+
+register_model_info(
+    {
+        "model_name": "FairMOT-DLA-34",
+        "suite": "MOT",
+        "config_path": osp.join(PDX_CONFIG_DIR, "FairMOT-DLA-34.yaml"),
+        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
+        "supported_dataset_types": ["MOTDataSet"],
+        "supported_train_opts": {
+            "device": ["cpu", "gpu_nxcx", "xpu", "npu", "mlu"],
+            "dy2st": False,
+            "amp": ["OFF"],
+        },
+        "hpi_config_path": HPI_CONFIG_DIR / "FairMOT-DLA-34.yaml",
+    }
+)
\ No newline at end of file
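Review note: `register.py` describes each model as a plain dictionary tied to a suite by name. The sketch below illustrates that registration pattern with stand-in registries. The dictionaries `_SUITES` and `_MODELS`, the helper `supports`, and the rule that a suite must be registered before its models are all assumptions made for illustration. They are not the implementation in `...base.register`.

```python
# Hypothetical stand-in registries; the real ones live in ...base.register.
_SUITES = {}
_MODELS = {}


def register_suite_info(info):
    """Store suite metadata under its suite name."""
    _SUITES[info["suite_name"]] = info


def register_model_info(info):
    """Store model metadata; this sketch requires the suite to exist first."""
    if info["suite"] not in _SUITES:
        raise ValueError(f"unknown suite: {info['suite']}")
    _MODELS[info["model_name"]] = info


def supports(model_name, api):
    """Return True if the registered model lists `api` in supported_apis."""
    return api in _MODELS[model_name]["supported_apis"]


register_suite_info({"suite_name": "MOT"})
register_model_info(
    {
        "model_name": "FairMOT-DLA-34",
        "suite": "MOT",
        "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
        "supported_dataset_types": ["MOTDataSet"],
    }
)

if __name__ == "__main__":
    print(supports("FairMOT-DLA-34", "train"))
```

Keeping the metadata declarative lets callers check capabilities such as `supported_apis` without importing the model code.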
diff --git a/paddlex/repo_apis/PaddleDetection_api/mot/runner.py b/paddlex/repo_apis/PaddleDetection_api/mot/runner.py
new file mode 100644
index 0000000000..4e502bdc37
--- /dev/null
+++ b/paddlex/repo_apis/PaddleDetection_api/mot/runner.py
@@ -0,0 +1,267 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+
+from ...base import BaseRunner
+from ...base.utils.arg import CLIArgument, gather_opts_args
+from ...base.utils.subprocess import CompletedProcess
+
+
+class MOTRunner(BaseRunner):
+    """MOTRunner"""
+
+    def train(
+        self,
+        config_path: str,
+        cli_args: list,
+        device: str,
+        ips: str,
+        save_dir: str,
+        do_eval=True,
+    ) -> CompletedProcess:
+        """train model
+
+        Args:
+            config_path (str): the config file path used to train.
+            cli_args (list): the additional parameters.
+            device (str): the training device.
+            ips (str): the ip addresses of nodes when using distribution.
+            save_dir (str): the directory path to save training output.
+            do_eval (bool, optional): whether or not to evaluate model during training. Defaults to True.
+
+        Returns:
+            CompletedProcess: the result of training subprocess execution.
+        """
+        args, env = self.distributed(device, ips, log_dir=save_dir)
+        cli_args = self._gather_opts_args(cli_args)
+        cmd = [*args, "tools/train.py"]
+        if do_eval:
+            cmd.append("--eval")
+        cmd.extend(["--config", config_path, *cli_args])
+        return self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(save_dir),
+        )
+
+    def evaluate(
+        self, config_path: str, cli_args: list, device: str, ips: str, arch: str
+    ) -> CompletedProcess:
+        """run model evaluating
+
+        Args:
+            config_path (str): the config file path used to evaluate.
+            cli_args (list): the additional parameters.
+            device (str): the evaluating device.
+            ips (str): the ip addresses of nodes when using distribution.
+            arch (str): the tracker architecture; "DeepSORT" skips `--training_weights`.
+
+        Returns:
+            CompletedProcess: the result of evaluating subprocess execution.
+        """
+        args, env = self.distributed(device, ips)
+        cli_args = self._gather_opts_args(cli_args)
+        cmd = [
+            *args,
+            "tools/eval_mot.py",
+            "--config",
+            config_path,
+            "--scaled=True",
+            *cli_args,
+        ]
+        if arch != "DeepSORT":
+            # DeepSORT needs no training, so it has no self-trained weights to load.
+            cmd.append("--training_weights")
+        cp = self.run_cmd(
+            cmd, env=env, switch_wdir=True, echo=True, silent=False, capture_output=True
+        )
+        if cp.returncode == 0:
+            metric_dict = _extract_mot_eval_metrics(cp.stdout)
+            cp.metrics = metric_dict
+        return cp
+
+    def predict(
+        self, config_path: str, cli_args: list, device: str
+    ) -> CompletedProcess:
+        """run predicting using dynamic mode
+
+        Args:
+            config_path (str): the config file path used to predict.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of predicting subprocess execution.
+        """
+        # `device` unused
+        cli_args = self._gather_opts_args(cli_args)
+        cmd = [self.python, "tools/infer.py", "-c", config_path, *cli_args]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def export(self, config_path: str, cli_args: list, device: str) -> CompletedProcess:
+        """run exporting
+
+        Args:
+            config_path (str): the path of config file used to export.
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of exporting subprocess execution.
+        """
+        # `device` unused
+        cli_args = self._gather_opts_args(cli_args)
+        cmd = [
+            self.python,
+            "tools/export_model.py",
+            "--for_fd",
+            "-c",
+            config_path,
+            *cli_args,
+        ]
+        cp = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+        return cp
+
+    def infer(self, cli_args: list, device: str) -> CompletedProcess:
+        """run predicting using inference model
+
+        Args:
+            cli_args (list): the additional parameters.
+            device (str): unused.
+
+        Returns:
+            CompletedProcess: the result of inferring subprocess execution.
+        """
+        # `device` unused
+        cmd = [self.python, "deploy/python/infer.py", "--use_fd_format", *cli_args]
+        return self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+    def compression(
+        self,
+        config_path: str,
+        train_cli_args: list,
+        export_cli_args: list,
+        device: str,
+        train_save_dir: str,
+    ) -> tuple:
+        """run compression model
+
+        Args:
+            config_path (str): the path of config file used for compression training and export.
+            train_cli_args (list): the additional training parameters.
+            export_cli_args (list): the additional exporting parameters.
+            device (str): the running device.
+            train_save_dir (str): the directory path to save output.
+
+        Returns:
+            tuple: (cp_train, cp_export), the CompletedProcess results of the
+                training and exporting subprocesses.
+        """
+        args, env = self.distributed(device, log_dir=train_save_dir)
+        train_cli_args = self._gather_opts_args(train_cli_args)
+        cmd = [*args, "tools/train.py", "-c", config_path, *train_cli_args]
+        cp_train = self.run_cmd(
+            cmd,
+            env=env,
+            switch_wdir=True,
+            echo=True,
+            silent=False,
+            capture_output=True,
+            log_path=self._get_train_log_path(train_save_dir),
+        )
+
+        # Export the weights written by the training step above.
+        cps_weight_path = os.path.join(train_save_dir, "model_final")
+        export_cli_args.append(CLIArgument("-o", f"weights={cps_weight_path}"))
+        export_cli_args = self._gather_opts_args(export_cli_args)
+        cmd = [
+            self.python,
+            "tools/export_model.py",
+            "--for_fd",
+            "-c",
+            config_path,
+            *export_cli_args,
+        ]
+        cp_export = self.run_cmd(cmd, switch_wdir=True, echo=True, silent=False)
+
+        return cp_train, cp_export
+
+    def _gather_opts_args(self, args):
+        """_gather_opts_args"""
+        return gather_opts_args(args, "-o")
+
+
+def _extract_eval_metrics(stdout):
+    """extract the COCO-style AP metric from the evaluation log
+
+    Args:
+        stdout (str): the evaluation log
+
+    Returns:
+        dict: the evaluation metric, {"AP": value}; 0 if no AP line is found
+    """
+    import re
+
+    pattern = r".*\(AP\)\s*@\[\s*IoU=0\.50:0\.95\s*\|\s*area=\s*all\s\|\smaxDets=\s*\d+\s\]\s*=\s*[0-1]?\.[0-9]{3}$"
+    key = "AP"
+
+    metric_dict = dict()
+    pattern = re.compile(pattern)
+    # TODO: Use lazy version to make it more efficient
+    lines = stdout.splitlines()
+    metric_dict[key] = 0
+    for line in lines:
+        match = pattern.search(line)
+        if match:
+            # The matched line ends with the value, e.g. "0.512"; keep the last 5 characters.
+            metric_dict[key] = float(match.group(0)[-5:])
+
+    return metric_dict
+
+
+def _extract_mot_eval_metrics(stdout):
+    """extract the MOTA metric from the MOT evaluation log
+
+    Args:
+        stdout (str): the evaluation log
+
+    Returns:
+        dict: the evaluation metric, {"MOTA": fraction}; 0 if no OVERALL row is found
+    """
+    import re
+
+    pattern = r"OVERALL"
+    key = "MOTA"
+
+    metric_dict = dict()
+    pattern = re.compile(pattern)
+    # TODO: Use lazy version to make it more efficient
+    lines = stdout.splitlines()
+    metric_dict[key] = 0
+    for line in lines:
+        match = pattern.search(line)
+        if match:
+            # Split the summary row on whitespace; the fifth field from the end
+            # is read as the MOTA percentage (e.g. "50.4%").
+            outs = line.split()
+            metric_dict[key] = float(outs[-5].split("%")[0]) / 100
+
+    return metric_dict
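Review note: the MOTA extraction in `runner.py` relies on the position of a field in the `OVERALL` summary row. The standalone sketch below reproduces that logic so it can be checked in isolation. The function name `parse_mota` and the sample row are made up for illustration: the numbers are synthetic and the column layout is an assumption, not output captured from `tools/eval_mot.py`.

```python
def parse_mota(stdout):
    """Return {"MOTA": fraction} from the last line containing OVERALL."""
    metrics = {"MOTA": 0.0}
    for line in stdout.splitlines():
        if "OVERALL" in line:
            fields = line.split()
            # Assumption: the fifth field from the end is MOTA, written as "NN.N%".
            metrics["MOTA"] = float(fields[-5].rstrip("%")) / 100
    return metrics


# Synthetic summary row; values are invented to exercise the parser only.
SAMPLE_LOG = "\n".join(
    [
        "some unrelated log line",
        "OVERALL 64.6% 70.0% 60.0% 80.0% 90.0% 100 50 40 10 200 300 20 30 56.7% 0.200 5 10 3",
    ]
)

if __name__ == "__main__":
    print(parse_mota(SAMPLE_LOG))
```

Because the index is positional, any change to the summary columns would silently shift the value, so a test with a known row is worth keeping next to the parser.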
diff --git a/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml b/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/ByteTrack_PP-YOLOE_L.yaml
@@ -0,0 +1,65 @@
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      enable_log_info: false
+      max_batch_size: null
+    tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      max_batch_size: null
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt
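Review note: the HPI config above lists a default backend per device under `selected_backends` and the allowed choices under `supported_backends`. The sketch below shows one way a caller could resolve a backend from those two sections. The function `pick_backend` and its override rule are hypothetical, and the config is copied into a plain dictionary so the example needs no YAML parser.

```python
# Plain-dict copy of the selected/supported backend sections of the YAML above.
hpi_config = {
    "selected_backends": {"cpu": "onnx_runtime", "gpu": "tensorrt"},
    "supported_backends": {
        "cpu": ["paddle_infer", "openvino", "onnx_runtime"],
        "gpu": ["paddle_infer", "paddle_tensorrt", "onnx_runtime", "tensorrt"],
    },
}


def pick_backend(config, device, requested=None):
    """Hypothetical resolver: use `requested` if given, else the selected default."""
    supported = config["supported_backends"][device]
    backend = requested if requested is not None else config["selected_backends"][device]
    if backend not in supported:
        raise ValueError(
            f"{backend!r} is not supported on {device!r}; choose from {supported}"
        )
    return backend


if __name__ == "__main__":
    print(pick_backend(hpi_config, "cpu"))
    print(pick_backend(hpi_config, "gpu", requested="paddle_infer"))
```

Validating against `supported_backends` up front gives a clear error instead of a failure deep inside the inference runtime.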
diff --git a/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml b/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/DeepSORT_PP-YOLOE_ResNet.yaml
@@ -0,0 +1,65 @@
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      enable_log_info: false
+      max_batch_size: null
+    tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      max_batch_size: null
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt
diff --git a/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml b/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml
new file mode 100644
index 0000000000..37ce8e69b2
--- /dev/null
+++ b/paddlex/utils/hpi_configs/FairMOT-DLA-34.yaml
@@ -0,0 +1,65 @@
+  backend_config:
+    onnx_runtime:
+      cpu_num_threads: 8
+    openvino:
+      cpu_num_threads: 8
+    paddle_infer:
+      cpu_num_threads: 8
+      enable_log_info: false
+    paddle_tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      enable_log_info: false
+      max_batch_size: null
+    tensorrt:
+      dynamic_shapes:
+        im_shape:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+        image:
+        - []
+        - []
+        - []
+        scale_factor:
+        - - 1
+          - 2
+        - - 1
+          - 2
+        - - 1
+          - 2
+      max_batch_size: null
+  selected_backends:
+    cpu: onnx_runtime
+    gpu: tensorrt
+  supported_backends:
+    cpu:
+    - paddle_infer
+    - openvino
+    - onnx_runtime
+    gpu:
+    - paddle_infer
+    - paddle_tensorrt
+    - onnx_runtime
+    - tensorrt