diff --git a/README.md b/README.md
index aeec441..8716893 100644
--- a/README.md
+++ b/README.md
@@ -18,31 +18,59 @@ The TorchYolo library aims to support all YOLO models(like YOLOv5, YOLOv6, YOLOv
 pip install torchyolo
 ```
 ### Prediction
+First download the [default_config.yaml](https://github.com/kadirnar/torchyolo/blob/tracker/torchyolo/default_config.yaml) file.
+
 ```python
 from torchyolo import YoloHub
-predictor = YoloHub(
-    model_type="yolov5",
-    model_path="yolov5s.pt",
-    device='cpu',
-    image_size=640
-)
-predictor.conf_thres = 0.25
-predictor.iou_thres = 0.45
-predictor.save = True
-predictor.show = False
-image = "data/highway.jpg"
-result = predictor.predict(image)
+
+model = YoloHub(config_path="torchyolo/default_config.yaml")
+result = model.predict(tracker=True)
+```
+
+### Configuration
+```yaml
+TRACKER_CONFIG:
+  # The name of the tracker
+  TRACKER_TYPE: NORFAIR_TRACK
+  # The path of the config file
+  CONFIG_PATH: torchyolo/configs/tracker/norfair_track.yaml
+  # The path of the model file
+  WEIGHT_PATH: osnet_x1_0_msmt17.pt
+
+
+DETECTOR_CONFIG:
+  # The name of the detector
+  DETECTOR_TYPE: yolov8 # yolov5, yolov7
+  # The threshold for the IOU score
+  IOU_TH: 0.45
+  # The threshold for the confidence score
+  CONF_TH: 0.25
+  # The size of the image
+  IMAGE_SIZE: 640
+  # The path of the weight file
+  MODEL_PATH: yolov8s.pt
+  # The device to run the detector
+  DEVICE: cuda:0
+  # F16 precision
+  HALF: False
+
+
+DATA_CONFIG:
+  # The path of the input video
+  INPUT_PATH: ../test.mp4
+  # The path of the output video
+  OUTPUT_PATH: Results
+  # Show the video
+  SHOW: False
+  # Save the video
+  SAVE: True
 ```
 
 ## Model Architecture
 ```python
 from torchyolo import YoloHub
 
-model = YoloHub(
-    model_type="yolov8",
-    model_path="yolov8n.pt",
-    device="cuda:0",
-    image_size=640)
+model = YoloHub(config_path="torchyolo/default_config.yaml")
 result = model.view_model(file_format="pdf")
 ```
@@ -62,14 +90,6 @@ A part of the code is borrowed from [SAHI](https://github.com/obss/sahi). Many t
 ### Citation
 
 ```bibtex
-@article{li2022yolov6,
-  title={YOLOv6: A single-stage object detection framework for industrial applications},
-  author={Li, Chuyi and Li, Lulu and Jiang, Hongliang and Weng, Kaiheng and Geng, Yifei and Li, Liang and Ke, Zaidan and Li, Qingyuan and Cheng, Meng and Nie, Weiqiang and others},
-  journal={arXiv preprint arXiv:2209.02976},
-  year={2022}
-}
-```
-```bibtex
 @article{wang2022yolov7,
   title={{YOLOv7}: Trainable bag-of-freebies sets new state-of-the-art for real-time object detectors},
   author={Wang, Chien-Yao and Bochkovskiy, Alexey and Liao, Hong-Yuan Mark},
@@ -77,14 +97,6 @@ A part of the code is borrowed from [SAHI](https://github.com/obss/sahi). Many t
   year={2022}
 }
 ```
-```bibtex
- @article{yolox2021,
-  title={YOLOX: Exceeding YOLO Series in 2021},
-  author={Ge, Zheng and Liu, Songtao and Wang, Feng and Li, Zeming and Sun, Jian},
-  journal={arXiv preprint arXiv:2107.08430},
-  year={2021}
-}
-```
 ```bibtex
 @software{glenn_jocher_2020_4154370,
   author = {Glenn Jocher and,Alex Stoken and,Jirka Borovec and,NanoCode012 and,ChristopherSTAN and,Liu Changyu and,Laughing and,tkianai and,Adam Hogan and,lorenzomammana and,yxNONG and,AlexWang1900 and,Laurentiu Diaconu and,Marc and,wanghaoyang0106 and,ml5ah and,Doug and,Francisco Ingham and,Frederik and,Guilhen and,Hatovix and,Jake Poznanski and,Jiacong Fang and,Lijun Yu δΊŽεŠ›ε†› and,changyu98 and,Mingyu Wang and,Naman Gupta and,Osama Akhtar and,PetrDvoracek and,Prashant Rai},
@@ -98,3 +110,38 @@ A part of the code is borrowed from [SAHI](https://github.com/obss/sahi). Many t
   url= {https://doi.org/10.5281/zenodo.4154370}
 }
 ```
+```bibtex
+@article{cao2022observation,
+  title={Observation-Centric SORT: Rethinking SORT for Robust Multi-Object Tracking},
+  author={Cao, Jinkun and Weng, Xinshuo and Khirodkar, Rawal and Pang, Jiangmiao and Kitani, Kris},
+  journal={arXiv preprint arXiv:2203.14360},
+  year={2022}
+}
+```
+```bibtex
+@inproceedings{zhang2022bytetrack,
+  title={{ByteTrack}: Multi-Object Tracking by Associating Every Detection Box},
+  author={Zhang, Yifu and Sun, Peize and Jiang, Yi and Yu, Dongdong and Weng, Fucheng and Yuan, Zehuan and Luo, Ping and Liu, Wenyu and Wang, Xinggang},
+  booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
+  year={2022}
+}
+```
+```bibtex
+@article{du2022strongsort,
+  title={{StrongSORT}: Make {DeepSORT} Great Again},
+  author={Du, Yunhao and Song, Yang and Yang, Bo and Zhao, Yanyun},
+  journal={arXiv preprint arXiv:2202.13514},
+  year={2022}
+}
+```
+```bibtex
+@inproceedings{Bewley2016_sort,
+  author={Bewley, Alex and Ge, Zongyuan and Ott, Lionel and Ramos, Fabio and Upcroft, Ben},
+  booktitle={2016 IEEE International Conference on Image Processing (ICIP)},
+  title={Simple online and realtime tracking},
+  year={2016},
+  pages={3464-3468},
+  keywords={Benchmark testing;Complexity theory;Detectors;Kalman filters;Target tracking;Visualization;Computer Vision;Data Association;Detection;Multiple Object Tracking},
+  doi={10.1109/ICIP.2016.7533003}
+}
+```
\ No newline at end of file
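The Configuration section above is now the single entry point to the API: `YoloHub` reads every detector, tracker, and I/O setting from this YAML instead of taking keyword arguments. A minimal sketch of inspecting the same file with plain PyYAML (the `yaml` import is an assumption of the sketch; torchyolo itself goes through its own `get_config` wrapper for attribute-style access):

```python
# Minimal sketch: reading the torchyolo config with plain PyYAML.
# torchyolo's own code uses get_config(), which wraps the same data
# for attribute-style access (config.DETECTOR_CONFIG.CONF_TH).
import yaml

with open("torchyolo/default_config.yaml") as f:
    config = yaml.safe_load(f)

detector = config["DETECTOR_CONFIG"]
print(detector["DETECTOR_TYPE"])                 # e.g. yolov8
print(detector["CONF_TH"], detector["IOU_TH"])   # 0.25 0.45
print(config["TRACKER_CONFIG"]["TRACKER_TYPE"])  # NORFAIR_TRACK
```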
diff --git a/torchyolo/__init__.py b/torchyolo/__init__.py
index 7d671fc..9ab8acf 100644
--- a/torchyolo/__init__.py
+++ b/torchyolo/__init__.py
@@ -1,3 +1,3 @@
 from torchyolo.predict import YoloHub
 
-__version__ = "0.2.3"
+__version__ = "0.3.0"
diff --git a/torchyolo/automodel.py b/torchyolo/automodel.py
index d55f262..7754400 100644
--- a/torchyolo/automodel.py
+++ b/torchyolo/automodel.py
@@ -1,64 +1,24 @@
-from typing import Optional
+from torchyolo.utils.config_utils import get_config
 
 MODEL_TYPE_TO_MODEL_CLASS_NAME = {
     "yolov5": "Yolov5DetectionModel",
-    "yolov6": "Yolov6DetectionModel",
     "yolov7": "Yolov7DetectionModel",
     "yolov8": "Yolov8DetectionModel",
-    "yolox": "YoloxDetectionModel",
 }
 
 
 class AutoDetectionModel:
     def from_pretrained(
-        model_type: str,
-        model_path: Optional[str] = None,
-        config_path: Optional[str] = None,
-        device: Optional[str] = None,
-        confidence_threshold: float = 0.3,
-        iou_threshold: float = 0.5,
-        image_size: int = None,
-        **kwargs,
+        config_path: str,
     ):
-        """
-        Loads a DetectionModel from given path.
-        Args:
-            model_type: str
-                Name of the detection framework (example: "yolov5", "mmdet", "detectron2")
-            model_path: str
-                Path of the detection model (ex. 'model.pt')
-            config_path: str
-                Path of the config file (ex. 'mmdet/configs/cascade_rcnn_r50_fpn_1x.py')
-            device: str
-                Device, "cpu" or "cuda:0"
-            mask_threshold: float
-                Value to threshold mask pixels, should be between 0 and 1
-            confidence_threshold: float
-                All predictions with score < confidence_threshold will be discarded
-            category_mapping: dict: str to str
-                Mapping from category id (str) to category name (str) e.g. {"1": "pedestrian"}
-            category_remapping: dict: str to int
-                Remap category ids based on category names, after performing inference e.g. {"car": 3}
-            load_at_init: bool
-                If True, automatically loads the model at initalization
-            image_size: int
-                Inference input size.
-        Returns:
-            Returns an instance of a DetectionModel
-        Raises:
-            ImportError: If given {model_type} framework is not installed
-        """
+        config = get_config(config_path)
+        model_type = config.DETECTOR_CONFIG.DETECTOR_TYPE
         model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
         DetectionModel = getattr(
             __import__(f"torchyolo.modelhub.{model_type}", fromlist=[model_class_name]), model_class_name
         )
+
         return DetectionModel(
-            model_path=model_path,
             config_path=config_path,
-            device=device,
-            confidence_threshold=confidence_threshold,
-            iou_threshold=iou_threshold,
-            image_size=image_size,
-            **kwargs,
         )
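`AutoDetectionModel` now resolves the detector class purely from the config: the `MODEL_TYPE_TO_MODEL_CLASS_NAME` registry maps the `DETECTOR_TYPE` string to a class name, and `__import__` with a non-empty `fromlist` returns the leaf module `torchyolo.modelhub.<type>` so `getattr` can pull the class off it. A standalone sketch of the same registry pattern, using the names from the diff:

```python
# Standalone sketch of the registry + dynamic-import pattern in
# AutoDetectionModel.from_pretrained.
REGISTRY = {"yolov5": "Yolov5DetectionModel"}

def resolve(model_type: str):
    class_name = REGISTRY[model_type]
    # A non-empty fromlist makes __import__ return the leaf module
    # (torchyolo.modelhub.yolov5) rather than the top-level package.
    module = __import__(f"torchyolo.modelhub.{model_type}", fromlist=[class_name])
    return getattr(module, class_name)

# DetectionModel = resolve("yolov5")
# model = DetectionModel(config_path="torchyolo/default_config.yaml")
```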
diff --git a/torchyolo/configs/tracker/__init__.py b/torchyolo/configs/tracker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchyolo/configs/tracker/byte_track.yaml b/torchyolo/configs/tracker/byte_track.yaml
new file mode 100644
index 0000000..5a111a4
--- /dev/null
+++ b/torchyolo/configs/tracker/byte_track.yaml
@@ -0,0 +1,5 @@
+BYTE_TRACK:
+  # The buffer for the track
+  TRACK_BUFFER: 25
+  # The frame rate of the video
+  FRAME_RATE: 30
\ No newline at end of file
diff --git a/torchyolo/configs/tracker/norfair_track.yaml b/torchyolo/configs/tracker/norfair_track.yaml
new file mode 100644
index 0000000..a4eaefe
--- /dev/null
+++ b/torchyolo/configs/tracker/norfair_track.yaml
@@ -0,0 +1,10 @@
+NORFAIR_TRACK:
+  DISTANCE_FUNCTION: "frobenius" # mean_manhattan, mean_euclidean, iou, iou_opt
+  DISTANCE_THRESHOLD: 500
+  HIT_COUNTER_MAX: 15
+  INITIALIZATION_DELAY: null
+  POINTWISE_HIT_COUNTER_MAX: 4
+  DETECTION_THRESHOLD: 0
+  PAST_DETECTIONS_LENGTH: 4
+  REID_DISTANCE_THRESHOLD: 0
+  REID_HIT_COUNTER_MAX: null
diff --git a/torchyolo/configs/tracker/oc_sort.yaml b/torchyolo/configs/tracker/oc_sort.yaml
new file mode 100644
index 0000000..92ef538
--- /dev/null
+++ b/torchyolo/configs/tracker/oc_sort.yaml
@@ -0,0 +1,17 @@
+OC_SORT:
+  # The threshold for the confidence score
+  CONF_THRESHOLD: 0.05
+  # The threshold for the IOU score
+  IOU_THRESHOLD: 0.3
+  # The maximum age of the track
+  MAX_AGE: 30
+  # The minimum number of hits for the track
+  MIN_HITS: 3
+  # The time interval between two frames
+  DELTA_T: 3
+  # The association function
+  ASSO_FUNC: "iou" # giou, ciou, diou, ct_dist
+  # The inertia of the track
+  INERTIA: 0.2
+  # Whether to use BYTE-style association for low-confidence boxes
+  USE_BYTE: False
diff --git a/torchyolo/configs/tracker/sort_track.yaml b/torchyolo/configs/tracker/sort_track.yaml
new file mode 100644
index 0000000..41652c5
--- /dev/null
+++ b/torchyolo/configs/tracker/sort_track.yaml
@@ -0,0 +1,5 @@
+SORT_TRACK:
+  # The maximum number of frames to keep alive a track without associated detections
+  MAX_AGE: 1
+  # The minimum number of associated detections before track initialization
+  MIN_HITS: 3
\ No newline at end of file
diff --git a/torchyolo/configs/tracker/strong_sort.yaml b/torchyolo/configs/tracker/strong_sort.yaml
new file mode 100644
index 0000000..8893dd8
--- /dev/null
+++ b/torchyolo/configs/tracker/strong_sort.yaml
@@ -0,0 +1,9 @@
+STRONG_SORT:
+  ECC: True # activate camera motion compensation
+  MC_LAMBDA: 0.995 # matching with both appearance (1 - MC_LAMBDA) and motion cost
+  EMA_ALPHA: 0.9 # updates appearance state in an exponential moving average manner
+  MAX_DIST: 0.2 # The matching threshold. Samples with larger distance are considered an invalid match
+  MAX_IOU_DISTANCE: 0.7 # Gating threshold. Associations with cost larger than this value are disregarded.
+  MAX_AGE: 30 # Maximum number of missed detections before a track is deleted
+  N_INIT: 3 # Number of frames that a track remains in initialization phase
+  NN_BUDGET: 100 # Maximum size of the appearance descriptors gallery
diff --git a/torchyolo/configs/yolov6/Arial.ttf b/torchyolo/configs/yolov6/Arial.ttf
deleted file mode 100644
index ab68fb1..0000000
Binary files a/torchyolo/configs/yolov6/Arial.ttf and /dev/null differ
diff --git a/torchyolo/configs/yolov6/coco.yaml b/torchyolo/configs/yolov6/coco.yaml
deleted file mode 100644
index 28faa6d..0000000
--- a/torchyolo/configs/yolov6/coco.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-# COCO 2017 dataset http://cocodataset.org
-train: ../coco/images/train2017 # 118287 images
-val: ../coco/images/val2017 # 5000 images
-test: ../coco/images/test2017
-anno_path: ../coco/annotations/instances_val2017.json
-# number of classes
-nc: 80
-# whether it is coco dataset, only coco dataset should be set to True.
-is_coco: True
-
-# class names
-names: [ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
-        'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
-        'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
-        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
-        'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
-        'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
-        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
-        'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
-        'hair drier', 'toothbrush' ]
diff --git a/torchyolo/configs/yolox/yolov3.py b/torchyolo/configs/yolox/yolov3.py
deleted file mode 100644
index e8e68a4..0000000
--- a/torchyolo/configs/yolox/yolov3.py
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-import torch.nn as nn
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 1.0
-        self.width = 1.0
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
-
-    def get_model(self, sublinear=False):
-        def init_yolo(M):
-            for m in M.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eps = 1e-3
-                    m.momentum = 0.03
-
-        if "model" not in self.__dict__:
-            from yolox.models import YOLOFPN, YOLOX, YOLOXHead
-
-            backbone = YOLOFPN()
-            head = YOLOXHead(self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu")
-            self.model = YOLOX(backbone, head)
-            self.model.apply(init_yolo)
-            self.model.head.initialize_biases(1e-2)
-
-        return self.model
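Each tracker YAML holds a single top-level section whose keys map one-to-one onto the constructor arguments wired up later in `tracker_zoo.py`. A sketch of that mapping for `byte_track.yaml` (the plain `yaml` import is an assumption of the sketch; torchyolo reads these files through `get_config`):

```python
# Sketch: the BYTE_TRACK section feeds BYTETracker's track_buffer and
# frame_rate kwargs in tracker_zoo.create_tracker; pyyaml is assumed.
import yaml

with open("torchyolo/configs/tracker/byte_track.yaml") as f:
    cfg = yaml.safe_load(f)["BYTE_TRACK"]

print(cfg["TRACK_BUFFER"], cfg["FRAME_RATE"])  # 25 30
```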
diff --git a/torchyolo/configs/yolox/yolox_l.py b/torchyolo/configs/yolox/yolox_l.py
deleted file mode 100644
index 50833ca..0000000
--- a/torchyolo/configs/yolox/yolox_l.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 1.0
-        self.width = 1.0
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/torchyolo/configs/yolox/yolox_m.py b/torchyolo/configs/yolox/yolox_m.py
deleted file mode 100644
index 9666a31..0000000
--- a/torchyolo/configs/yolox/yolox_m.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 0.67
-        self.width = 0.75
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/torchyolo/configs/yolox/yolox_nano.py b/torchyolo/configs/yolox/yolox_nano.py
deleted file mode 100644
index 4c3944f..0000000
--- a/torchyolo/configs/yolox/yolox_nano.py
+++ /dev/null
@@ -1,48 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-import torch.nn as nn
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 0.33
-        self.width = 0.25
-        self.input_size = (416, 416)
-        self.random_size = (10, 20)
-        self.mosaic_scale = (0.5, 1.5)
-        self.test_size = (416, 416)
-        self.mosaic_prob = 0.5
-        self.enable_mixup = False
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
-
-    def get_model(self, sublinear=False):
-        def init_yolo(M):
-            for m in M.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eps = 1e-3
-                    m.momentum = 0.03
-
-        if "model" not in self.__dict__:
-            from yolox.models import YOLOPAFPN, YOLOX, YOLOXHead
-
-            in_channels = [256, 512, 1024]
-            # NANO model use depthwise = True, which is main difference.
-            backbone = YOLOPAFPN(
-                self.depth,
-                self.width,
-                in_channels=in_channels,
-                act=self.act,
-                depthwise=True,
-            )
-            head = YOLOXHead(self.num_classes, self.width, in_channels=in_channels, act=self.act, depthwise=True)
-            self.model = YOLOX(backbone, head)
-
-        self.model.apply(init_yolo)
-        self.model.head.initialize_biases(1e-2)
-        return self.model
diff --git a/torchyolo/configs/yolox/yolox_s.py b/torchyolo/configs/yolox/yolox_s.py
deleted file mode 100644
index abb6a8b..0000000
--- a/torchyolo/configs/yolox/yolox_s.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 0.33
-        self.width = 0.50
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/torchyolo/configs/yolox/yolox_tiny.py b/torchyolo/configs/yolox/yolox_tiny.py
deleted file mode 100644
index 5220de2..0000000
--- a/torchyolo/configs/yolox/yolox_tiny.py
+++ /dev/null
@@ -1,20 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 0.33
-        self.width = 0.375
-        self.input_size = (416, 416)
-        self.mosaic_scale = (0.5, 1.5)
-        self.random_size = (10, 20)
-        self.test_size = (416, 416)
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
-        self.enable_mixup = False
diff --git a/torchyolo/configs/yolox/yolox_x.py b/torchyolo/configs/yolox/yolox_x.py
deleted file mode 100644
index ac498a1..0000000
--- a/torchyolo/configs/yolox/yolox_x.py
+++ /dev/null
@@ -1,15 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-# Copyright (c) Megvii, Inc. and its affiliates.
-
-import os
-
-from yolox.exp import Exp as MyExp
-
-
-class Exp(MyExp):
-    def __init__(self):
-        super(Exp, self).__init__()
-        self.depth = 1.33
-        self.width = 1.25
-        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]
diff --git a/torchyolo/default_config.yaml b/torchyolo/default_config.yaml
new file mode 100644
index 0000000..3c2e9b9
--- /dev/null
+++ b/torchyolo/default_config.yaml
@@ -0,0 +1,34 @@
+TRACKER_CONFIG:
+  # The name of the tracker
+  TRACKER_TYPE: NORFAIR_TRACK
+  # The path of the config file
+  CONFIG_PATH: torchyolo/configs/tracker/norfair_track.yaml
+  # The path of the model file
+  WEIGHT_PATH: osnet_x1_0_msmt17.pt
+
+
+DETECTOR_CONFIG:
+  # The name of the detector
+  DETECTOR_TYPE: yolov8 # yolov5, yolov7
+  # The threshold for the IOU score
+  IOU_TH: 0.45
+  # The threshold for the confidence score
+  CONF_TH: 0.25
+  # The size of the image
+  IMAGE_SIZE: 640
+  # The path of the weight file
+  MODEL_PATH: yolov8s.pt
+  # The device to run the detector
+  DEVICE: cuda:0
+  # F16 precision
+  HALF: False
+
+DATA_CONFIG:
+  # The path of the input video
+  INPUT_PATH: ../test.mp4
+  # The path of the output video
+  OUTPUT_PATH: Results
+  # Show the video
+  SHOW: False
+  # Save the video
+  SAVE: True
diff --git a/torchyolo/modelhub/basemodel.py b/torchyolo/modelhub/basemodel.py
deleted file mode 100644
index 38d4234..0000000
--- a/torchyolo/modelhub/basemodel.py
+++ /dev/null
@@ -1,64 +0,0 @@
-from typing import Optional
-
-import numpy as np
-import torch
-
-
-class YoloDetectionModel:
-    def __init__(
-        self,
-        model_path: Optional[str] = None,
-        config_path: Optional[str] = None,
-        device: Optional[str] = None,
-        confidence_threshold: float = 0.3,
-        iou_threshold: float = 0.5,
-        image_size: int = 640,
-    ):
-        """
-        Init object detection model.
-        Args:
-            model_path: str
-                Path for the instance segmentation model weight
-            config_path: str
-                Path for the mmdetection instance segmentation model config file
-            device: str
-                Torch device, "cpu" or "cuda"
-            iou_threshold: float
-                All predictions with IoU < iou_threshold will be discarded
-            confidence_threshold: float
-                All predictions with score < confidence_threshold will be discarded
-            image_size: int
-                Inference input size.
-        """
-        self.model_path = model_path
-        self.config_path = config_path
-        self.device = device
-        self.iou_threshold = iou_threshold
-        self.confidence_threshold = confidence_threshold
-        self.image_size = image_size
-        self.yaml_file = "torchyolo/configs/yolov6/coco.yaml"
-        if self.save:
-            self.save_path = "output"
-            self.output_file_name = "prediction_visual"
-
-        # automatically load model if load_at_init is True
-        self.load_model()
-
-        if self.device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    def load_model(self):
-        """
-        This function should be implemented in a way that detection model
-        should be initialized and set to self.model.
-        (self.model_path, self.config_path, and self.device should be utilized)
-        """
-        raise NotImplementedError()
-
-    def predict(self, image: np.ndarray, yaml_file: str = None):
-        """
-        This function should be implemented in a way that detection model
-        should be initialized and set to self.model.
-        (self.model_path, self.config_path, and self.device should be utilized)
-        """
-        raise NotImplementedError()
diff --git a/torchyolo/modelhub/yolov5.py b/torchyolo/modelhub/yolov5.py
index eaec77e..e833bf2 100644
--- a/torchyolo/modelhub/yolov5.py
+++ b/torchyolo/modelhub/yolov5.py
@@ -2,49 +2,100 @@
 import yolov5
 from tqdm import tqdm
 
-from torchyolo.modelhub.basemodel import YoloDetectionModel
+from torchyolo.tracker_zoo import load_tracker
+from torchyolo.utils.config_utils import get_config
 from torchyolo.utils.dataset import LoadData, create_video_writer
 from torchyolo.utils.object_vis import video_vis
 
 
-class Yolov5DetectionModel(YoloDetectionModel):
+class Yolov5DetectionModel:
+    def __init__(self, config_path: str):
+        self.load_config(config_path)
+        self.load_model()
+
+    def load_config(self, config_path: str):
+        self.config_path = config_path
+        config = get_config(config_path)
+        self.input_path = config.DATA_CONFIG.INPUT_PATH
+        self.output_path = config.DATA_CONFIG.OUTPUT_PATH
+        self.model_type = config.DETECTOR_CONFIG.DETECTOR_TYPE
+        self.model_path = config.DETECTOR_CONFIG.MODEL_PATH
+        self.device = config.DETECTOR_CONFIG.DEVICE
+        self.conf = config.DETECTOR_CONFIG.CONF_TH
+        self.iou = config.DETECTOR_CONFIG.IOU_TH
+        self.image_size = config.DETECTOR_CONFIG.IMAGE_SIZE
+        self.save = config.DATA_CONFIG.SAVE
+        self.show = config.DATA_CONFIG.SHOW
+
     def load_model(self):
         model = yolov5.load(self.model_path, device=self.device)
-        model.conf = self.confidence_threshold
-        model.iou = self.iou_threshold
+        model.conf = self.conf
+        model.iou = self.iou
         self.model = model
 
-    def predict(self, input_path, yaml_file=None, save=False, show=False):
+    def predict(self, tracker=True):
+        tracker_module = load_tracker(self.config_path)
+        config = get_config(self.config_path)
+        input_path = config.DATA_CONFIG.INPUT_PATH
+
+        tracker_outputs = [None]
         dataset = LoadData(input_path)
         video_writer = create_video_writer(video_path=input_path, output_path="output")
 
         for img_src, img_path, vid_cap in tqdm(dataset):
-            results = self.model(img_src, augment=False)
-            for index, prediction in enumerate(results.pred):
-                for pred in prediction.cpu().detach().numpy():
-                    x1, y1, x2, y2 = (
-                        int(pred[0]),
-                        int(pred[1]),
-                        int(pred[2]),
-                        int(pred[3]),
-                    )
-                    bbox = [x1, y1, x2, y2]
-                    score = pred[4]
-                    category_name = self.model.names[int(pred[5])]
-                    category_id = int(pred[5])
-                    label = f"{category_name} {score:.2f}"
-
-                    frame = video_vis(
-                        bbox=bbox,
-                        label=label,
-                        frame=img_src,
-                        object_id=category_id,
-                    )
-                    if save:
-                        video_writer.write(frame)
-
-                    if show:
-                        cv2.imshow("frame", frame)
-                        if cv2.waitKey(1) & 0xFF == ord("q"):
-                            break
+            results = self.model(img_src)
+            for image_id, prediction in enumerate(results.pred):
+                if tracker:
+                    tracker_outputs[image_id] = tracker_module.update(prediction.cpu(), img_src)
+                    for output in tracker_outputs[image_id]:
+                        bbox, track_id, category_id, score = (
+                            output[:4],
+                            int(output[4]),
+                            output[5],
+                            output[6],
+                        )
+                        category_name = self.model.names[int(category_id)]
+                        label = f"Id:{track_id} {category_name} {float(score):.2f}"
+
+                        if self.save or self.show:
+                            frame = video_vis(
+                                bbox=bbox,
+                                label=label,
+                                frame=img_src,
+                                object_id=int(category_id),
+                            )
+                            if self.save:
+                                video_writer.write(frame)
+
+                            if self.show:
+                                cv2.imshow("frame", frame)
+                                if cv2.waitKey(1) & 0xFF == ord("q"):
+                                    break
+
+                else:
+                    for pred in prediction.cpu().detach().numpy():
+                        x1, y1, x2, y2 = (
+                            int(pred[0]),
+                            int(pred[1]),
+                            int(pred[2]),
+                            int(pred[3]),
+                        )
+                        bbox = [x1, y1, x2, y2]
+                        score = pred[4]
+                        category_name = self.model.names[int(pred[5])]
+                        category_id = int(pred[5])
+                        label = f"{category_name} {score:.2f}"
+
+                        frame = video_vis(
+                            bbox=bbox,
+                            label=label,
+                            frame=img_src,
+                            object_id=category_id,
+                        )
+                        if self.save:
+                            video_writer.write(frame)
+
+                        if self.show:
+                            cv2.imshow("frame", frame)
+                            if cv2.waitKey(1) & 0xFF == ord("q"):
+                                break
 
         video_writer.release()
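In the tracker branch above, every row returned by `tracker_module.update` is unpacked as `[x1, y1, x2, y2, track_id, class_id, score]`. A self-contained sketch of that unpacking on a dummy row (class 2 is "car" for COCO-trained weights):

```python
# Sketch of the row layout the predict() loops assume for tracker
# outputs: [x1, y1, x2, y2, track_id, class_id, score].
import numpy as np

output = np.array([100.0, 50.0, 220.0, 180.0, 7.0, 2.0, 0.91])
bbox, track_id, category_id, score = (
    output[:4],
    int(output[4]),
    output[5],
    output[6],
)
label = f"Id:{track_id} car {float(score):.2f}"
print(bbox, label)  # [100.  50. 220. 180.] Id:7 car 0.91
```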
diff --git a/torchyolo/modelhub/yolov8.py b/torchyolo/modelhub/yolov8.py
index cc2eb4b..42f18b9 100644
--- a/torchyolo/modelhub/yolov8.py
+++ b/torchyolo/modelhub/yolov8.py
@@ -1,56 +1,107 @@
 import cv2
-from sahi.prediction import ObjectPrediction, PredictionResult
-from sahi.utils.cv import visualize_object_predictions
+import torch
+from tqdm import tqdm
 from ultralytics import YOLO
 
-from torchyolo.modelhub.basemodel import YoloDetectionModel
+from torchyolo.tracker_zoo import load_tracker
+from torchyolo.utils.config_utils import get_config
+from torchyolo.utils.dataset import LoadData, create_video_writer
+from torchyolo.utils.object_vis import video_vis
 
 
-class Yolov8DetectionModel(YoloDetectionModel):
+class Yolov8DetectionModel:
+    def __init__(self, config_path: str):
+        self.load_config(config_path)
+        self.load_model()
+
+    def load_config(self, config_path: str):
+        self.config_path = config_path
+        config = get_config(config_path)
+        self.input_path = config.DATA_CONFIG.INPUT_PATH
+        self.output_path = config.DATA_CONFIG.OUTPUT_PATH
+        self.model_type = config.DETECTOR_CONFIG.DETECTOR_TYPE
+        self.model_path = config.DETECTOR_CONFIG.MODEL_PATH
+        self.device = config.DETECTOR_CONFIG.DEVICE
+        self.conf = config.DETECTOR_CONFIG.CONF_TH
+        self.iou = config.DETECTOR_CONFIG.IOU_TH
+        self.image_size = config.DETECTOR_CONFIG.IMAGE_SIZE
+        self.save = config.DATA_CONFIG.SAVE
+        self.show = config.DATA_CONFIG.SHOW
+
     def load_model(self):
         model = YOLO(self.model_path)
-        model.conf = self.confidence_threshold
-        model.iou = self.iou_threshold
+        model.conf = self.conf
+        model.iou = self.iou
         self.model = model
 
-    def predict(self, image, yaml_file=None, tracker=False):
-        prediction = self.model(image, imgsz=self.image_size)
-        if tracker:
-            return prediction
-        else:
-            object_prediction_list = []
-            for _, image_predictions_in_xyxy_format in enumerate(prediction):
-                for pred in image_predictions_in_xyxy_format.cpu().detach().numpy():
-                    x1, y1, x2, y2 = (
-                        int(pred[0]),
-                        int(pred[1]),
-                        int(pred[2]),
-                        int(pred[3]),
-                    )
-                    bbox = [x1, y1, x2, y2]
-                    score = pred[4]
-                    category_name = self.model.model.names[int(pred[5])]
-                    category_id = pred[5]
-                    object_prediction = ObjectPrediction(
-                        bbox=bbox,
-                        category_id=int(category_id),
-                        score=score,
-                        category_name=category_name,
-                    )
-                    object_prediction_list.append(object_prediction)
-
-            prediction_result = PredictionResult(
-                object_prediction_list=object_prediction_list,
-                image=image,
-            )
-            if self.save:
-                prediction_result.export_visuals(export_dir=self.save_path, file_name=self.output_file_name)
-
-            if self.show:
-                image = cv2.imread(image)
-                output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
-                cv2.imshow("Prediction", output_image["image"])
-                cv2.waitKey(0)
-                cv2.destroyAllWindows()
-
-            return prediction_result
+    def predict(self, tracker=True):
+        tracker_module = load_tracker(self.config_path)
+        config = get_config(self.config_path)
+        input_path = config.DATA_CONFIG.INPUT_PATH
+
+        tracker_outputs = [None]
+        dataset = LoadData(input_path)
+        video_writer = create_video_writer(video_path=input_path, output_path="output")
+
+        for img_src, img_path, vid_cap in tqdm(dataset):
+            results = self.model.predict(img_src, imgsz=self.image_size)
+            for image_id, prediction in enumerate(results):
+                if tracker:
+                    boxes = prediction[:].boxes.xyxy
+                    score = prediction[:].boxes.conf
+                    category_id = prediction[:].boxes.cls
+                    dets = torch.cat((boxes, score.unsqueeze(1), category_id.unsqueeze(1)), dim=1)
+                    tracker_outputs[image_id] = tracker_module.update(dets.cpu(), img_src)
+                    for output in tracker_outputs[image_id]:
+                        bbox, track_id, category_id, score = (
+                            output[:4],
+                            int(output[4]),
+                            output[5],
+                            output[6],
+                        )
+                        category_name = self.model.model.names[int(category_id)]
+                        label = f"Id:{track_id} {category_name} {float(score):.2f}"
+
+                        if self.save or self.show:
+                            frame = video_vis(
+                                bbox=bbox,
+                                label=label,
+                                frame=img_src,
+                                object_id=int(category_id),
+                            )
+                            if self.save:
+                                video_writer.write(frame)
+
+                            if self.show:
+                                cv2.imshow("frame", frame)
+                                if cv2.waitKey(1) & 0xFF == ord("q"):
+                                    break
+
+                else:
+                    for pred in prediction.cpu().detach().numpy():
+                        x1, y1, x2, y2 = (
+                            int(pred[0]),
+                            int(pred[1]),
+                            int(pred[2]),
+                            int(pred[3]),
+                        )
+                        bbox = [x1, y1, x2, y2]
+                        score = pred[4]
+                        category_name = self.model.names[int(pred[5])]
+                        category_id = int(pred[5])
+                        label = f"{category_name} {score:.2f}"
+
+                        frame = video_vis(
+                            bbox=bbox,
+                            label=label,
+                            frame=img_src,
+                            object_id=category_id,
+                        )
+                        if self.save:
+                            video_writer.write(frame)
+
+                        if self.show:
+                            cv2.imshow("frame", frame)
+                            if cv2.waitKey(1) & 0xFF == ord("q"):
+                                break
+
+        video_writer.release()
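Unlike the yolov5/yolov7 wrappers, whose `results.pred` already comes as `(N, 6)` tensors, ultralytics results keep boxes, scores, and classes in separate tensors, so the yolov8 branch rebuilds the `(N, 6)` detections tensor the trackers expect. A sketch of that concatenation with dummy tensors:

```python
# Sketch: building the (N, 6) [x1, y1, x2, y2, score, class] tensor
# consumed by the trackers, from separate ultralytics-style tensors.
import torch

boxes = torch.tensor([[100.0, 50.0, 220.0, 180.0]])  # (N, 4) xyxy
score = torch.tensor([0.91])                         # (N,)
category_id = torch.tensor([2.0])                    # (N,)

dets = torch.cat((boxes, score.unsqueeze(1), category_id.unsqueeze(1)), dim=1)
print(dets.shape)  # torch.Size([1, 6])
```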
diff --git a/torchyolo/modelhub/yolox.py b/torchyolo/modelhub/yolox.py
deleted file mode 100644
index 37942e2..0000000
--- a/torchyolo/modelhub/yolox.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import cv2
-from sahi.prediction import ObjectPrediction, PredictionResult
-from sahi.utils.cv import visualize_object_predictions
-from yoloxdetect import YoloxDetector
-
-from torchyolo.modelhub.basemodel import YoloDetectionModel
-
-
-class YoloxDetectionModel(YoloDetectionModel):
-    def load_model(self):
-        model = YoloxDetector(self.model_path, config_path=self.config_path, device=self.device, hf_model=False)
-        model.torchyolo = True
-        model.conf = self.confidence_threshold
-        model.iou = self.iou_threshold
-        model.save = self.save
-        model.show = self.show
-        self.model = model
-
-    def predict(self, image, yaml_file=None, tracker=False):
-        object_prediction_list = []
-        predict_list = self.model.predict(image_path=image, image_size=self.image_size)
-        if tracker:
-            return predict_list
-        else:
-            boxes, scores, cls_ids, class_names = predict_list[0], predict_list[1], predict_list[2], predict_list[3]
-            for i in range(len(boxes)):
-                box = boxes[i]
-                category_id = int(cls_ids[i])
-                score = scores[i]
-                if score < self.confidence_threshold:
-                    continue
-                x0 = int(box[0])
-                y0 = int(box[1])
-                x1 = int(box[2])
-                y1 = int(box[3])
-                bbox = [x0, y0, x1, y1]
-                category_name = class_names[category_id]
-
-                object_prediction = ObjectPrediction(
-                    bbox=bbox,
-                    score=score,
-                    category_id=category_id,
-                    category_name=category_name,
-                )
-                object_prediction_list.append(object_prediction)
-
-            prediction_result = PredictionResult(
-                object_prediction_list=object_prediction_list,
-                image=image,
-            )
-            if self.save:
-                prediction_result.export_visuals(export_dir=self.save_path, file_name=self.output_file_name)
-
-            if self.show:
-                image = cv2.imread(image)
-                output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
-                cv2.imshow("Prediction", output_image["image"])
-                cv2.waitKey(0)
-                cv2.destroyAllWindows()
-
-            return prediction_result
diff --git a/torchyolo/predict.py b/torchyolo/predict.py
index e318a94..ec5f497 100644
--- a/torchyolo/predict.py
+++ b/torchyolo/predict.py
@@ -1,38 +1,31 @@
-from typing import Optional
-
 from torchyolo.automodel import AutoDetectionModel
+from torchyolo.utils.config_utils import get_config
 
 
 class YoloHub:
-    def __init__(
-        self,
-        model_type: str = "yolov5",
-        model_path: str = "yolov5s.pt",
-        device: str = "cpu",
-        image_size: int = 640,
-        config_path: Optional[str] = "configs.yolox.yolox_s",
-    ):
-        self.model_type = model_type
-        self.model_path = model_path
+    def __init__(self, config_path: str):
+        self.load_config(config_path)
+
+    def load_config(self, config_path: str):
         self.config_path = config_path
-        self.device = device
-        self.conf = 0.45
-        self.iou = 0.45
-        self.image_size = image_size
-        self.model = None
+        config = get_config(config_path)
+        self.input_path = config.DATA_CONFIG.INPUT_PATH
+        self.output_path = config.DATA_CONFIG.OUTPUT_PATH
+        self.model_type = config.DETECTOR_CONFIG.DETECTOR_TYPE
+        self.model_path = config.DETECTOR_CONFIG.MODEL_PATH
+        self.device = config.DETECTOR_CONFIG.DEVICE
+        self.conf = config.DETECTOR_CONFIG.CONF_TH
+        self.iou = config.DETECTOR_CONFIG.IOU_TH
+        self.image_size = config.DETECTOR_CONFIG.IMAGE_SIZE
+        self.save = config.DATA_CONFIG.SAVE
+        self.show = config.DATA_CONFIG.SHOW
 
         # Load Model
         self.load_model()
 
     def load_model(self):
         model = AutoDetectionModel.from_pretrained(
-            model_type=self.model_type,
-            model_path=self.model_path,
             config_path=self.config_path,
-            device=self.device,
-            confidence_threshold=self.conf,
-            iou_threshold=self.iou,
-            image_size=self.image_size,
         )
         self.model = model
         return model
@@ -62,10 +55,10 @@ def view_model_architecture(self, file_format: str = "pdf"):
         model_graph.visual_graph.render(format=file_format)
         return model_graph
 
-    def predict(self, image, yaml_file=None, save=False, show=False):
-        self.model.predict(image, yaml_file, save, show)
+    def predict(self, tracker: bool = False):
+        return self.model.predict(tracker)
 
 
 if __name__ == "__main__":
-    model = YoloHub(model_type="yolov5", model_path="yolov5n.pt", device="cuda:0", image_size=640)
-    result = model.predict("../test.mp4", save=True, show=False)
+    model = YoloHub(config_path="torchyolo/default_config.yaml")
+    result = model.predict(tracker=True)
diff --git a/torchyolo/tracker_zoo.py b/torchyolo/tracker_zoo.py
new file mode 100644
index 0000000..5f3da90
--- /dev/null
+++ b/torchyolo/tracker_zoo.py
@@ -0,0 +1,141 @@
+from pathlib import Path
+from typing import Optional
+
+from torchyolo.utils.config_utils import get_config
+
+DEFAULT_BYTETRACK_CONFIG_PATH = "torchyolo/configs/tracker/byte_track.yaml"
+DEFAULT_OCSORT_CONFIG_PATH = "torchyolo/configs/tracker/oc_sort.yaml"
+DEFAULT_NORFAIR_CONFIG_PATH = "torchyolo/configs/tracker/norfair_track.yaml"
+DEFAULT_STRONGSORT_CONFIG_PATH = "torchyolo/configs/tracker/strong_sort.yaml"
+DEFAULT_SORT_CONFIG_PATH = "torchyolo/configs/tracker/sort_track.yaml"
+
+
+def create_tracker(
+    tracker_type,
+    tracker_config_path,
+    tracker_weight_path: Optional[str] = None,
+    device: Optional[str] = "cpu",
+    half: Optional[bool] = False,
+    conf_th: Optional[float] = 0.05,
+    iou_th: Optional[float] = 0.05,
+) -> object:
+    if tracker_type == "OC_SORT":
+        from ocsort.ocsort import OCSort
+
+        if tracker_config_path is None:
+            config_path = DEFAULT_OCSORT_CONFIG_PATH
+        else:
+            config_path = tracker_config_path
+
+        config = get_config(config_path)
+        oc_sort = OCSort(
+            det_thresh=conf_th,
+            max_age=config.OC_SORT.MAX_AGE,
+            min_hits=config.OC_SORT.MIN_HITS,
+            iou_threshold=iou_th,
+            delta_t=config.OC_SORT.DELTA_T,
+            asso_func=config.OC_SORT.ASSO_FUNC,
+            inertia=config.OC_SORT.INERTIA,
+            use_byte=config.OC_SORT.USE_BYTE,
+        )
+        return oc_sort
+
+    elif tracker_type == "BYTE_TRACK":
+        from bytetracker.byte_tracker import BYTETracker
+
+        if tracker_config_path is None:
+            config_path = DEFAULT_BYTETRACK_CONFIG_PATH
+        else:
+            config_path = tracker_config_path
+
+        config = get_config(config_path)
+
+        byte_tracker = BYTETracker(
+            track_thresh=conf_th,
+            track_buffer=config.BYTE_TRACK.TRACK_BUFFER,
+            frame_rate=config.BYTE_TRACK.FRAME_RATE,
+        )
+        return byte_tracker
+
+    elif tracker_type == "NORFAIR_TRACK":
+        from norfair_tracker.norfair import NorFairTracker
+
+        if tracker_config_path is None:
+            config_path = DEFAULT_NORFAIR_CONFIG_PATH
+        else:
+            config_path = tracker_config_path
+
+        config = get_config(config_path)
+        norfair_tracker = NorFairTracker(
+            distance_function=config.NORFAIR_TRACK.DISTANCE_FUNCTION,
+            distance_threshold=config.NORFAIR_TRACK.DISTANCE_THRESHOLD,
+            hit_counter_max=config.NORFAIR_TRACK.HIT_COUNTER_MAX,
+            initialization_delay=config.NORFAIR_TRACK.INITIALIZATION_DELAY,
+            pointwise_hit_counter_max=config.NORFAIR_TRACK.POINTWISE_HIT_COUNTER_MAX,
+            detection_threshold=config.NORFAIR_TRACK.DETECTION_THRESHOLD,
+            past_detections_length=config.NORFAIR_TRACK.PAST_DETECTIONS_LENGTH,
+            reid_distance_threshold=config.NORFAIR_TRACK.REID_DISTANCE_THRESHOLD,
+            reid_hit_counter_max=config.NORFAIR_TRACK.REID_HIT_COUNTER_MAX,
+        )
+        return norfair_tracker
+
+    elif tracker_type == "SORT_TRACK":
+        from sort.tracker import SortTracker
+
+        if tracker_config_path is None:
+            config_path = DEFAULT_SORT_CONFIG_PATH
+        else:
+            config_path = tracker_config_path
+
+        config = get_config(config_path)
+        sort_tracker = SortTracker(
+            max_age=config.SORT_TRACK.MAX_AGE,
+            min_hits=config.SORT_TRACK.MIN_HITS,
+            iou_threshold=iou_th,
+        )
+        return sort_tracker
+
+    elif tracker_type == "STRONG_SORT":
+        from strongsort.strong_sort import StrongSORT
+
+        if tracker_config_path is None:
+            config_path = DEFAULT_STRONGSORT_CONFIG_PATH
+        else:
+            config_path = tracker_config_path
+
+        config = get_config(config_path)
+        strong_sort = StrongSORT(
+            tracker_weight_path,
+            device,
+            half,
+            max_dist=config.STRONG_SORT.MAX_DIST,
+            max_iou_distance=config.STRONG_SORT.MAX_IOU_DISTANCE,
+            max_age=config.STRONG_SORT.MAX_AGE,
+            n_init=config.STRONG_SORT.N_INIT,
+            nn_budget=config.STRONG_SORT.NN_BUDGET,
+            mc_lambda=config.STRONG_SORT.MC_LAMBDA,
+            ema_alpha=config.STRONG_SORT.EMA_ALPHA,
+        )
+        return strong_sort
+
+    else:
+        raise ValueError(f"No such tracker: {tracker_type}")
+
+
+def load_tracker(config_path: str) -> object:
+    """
+    Loads the tracker specified in the given torchyolo config file.
+    Args:
+        config_path: path of the torchyolo config file (str)
+    """
+    config = get_config(config_path)
+    tracker_module = create_tracker(
+        tracker_type=config.TRACKER_CONFIG.TRACKER_TYPE,
+        tracker_weight_path=Path(config.TRACKER_CONFIG.WEIGHT_PATH),
+        tracker_config_path=config.TRACKER_CONFIG.CONFIG_PATH,
+        device=config.DETECTOR_CONFIG.DEVICE,
+        half=config.DETECTOR_CONFIG.HALF,
+        conf_th=config.DETECTOR_CONFIG.CONF_TH,
+        iou_th=config.DETECTOR_CONFIG.IOU_TH,
+    )
+    return tracker_module
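`load_tracker` ties the two config layers together: it reads the tracker type, weight path, and detector thresholds from the main torchyolo config and forwards them to `create_tracker`, which reads the tracker-specific YAML. The returned object exposes the `update(dets, frame)` interface the modelhub classes call each frame. A usage sketch (paths assumed to exist on disk):

```python
# Usage sketch: resolving a tracker from the main torchyolo config.
# Assumes the config and tracker weight paths exist; dets is an (N, 6)
# tensor of [x1, y1, x2, y2, score, class] rows as built in modelhub.
from torchyolo.tracker_zoo import load_tracker

tracker = load_tracker("torchyolo/default_config.yaml")
# inside the per-frame loop:
#   tracks = tracker.update(dets.cpu(), frame)
#   each row of tracks: [x1, y1, x2, y2, track_id, class_id, score]
```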
diff --git a/torchyolo/utils/config_utils.py b/torchyolo/utils/config_utils.py
index 1b4c329..5319158 100644
--- a/torchyolo/utils/config_utils.py
+++ b/torchyolo/utils/config_utils.py
@@ -41,3 +41,31 @@ def get_config(config_file: str = None) -> YamlParser:
     config = YamlParser(config_file=config_file)
     config.merge_from_file(config_file)
     return config
+
+
+def load_config(config_path: str):
+    config = get_config(config_path)
+    input_path = config.DATA_CONFIG.INPUT_PATH
+    output_path = config.DATA_CONFIG.OUTPUT_PATH
+    model_type = config.DETECTOR_CONFIG.DETECTOR_TYPE
+    model_path = config.DETECTOR_CONFIG.MODEL_PATH
+    device = config.DETECTOR_CONFIG.DEVICE
+    conf = config.DETECTOR_CONFIG.CONF_TH
+    iou = config.DETECTOR_CONFIG.IOU_TH
+    image_size = config.DETECTOR_CONFIG.IMAGE_SIZE
+    save = config.DATA_CONFIG.SAVE
+    show = config.DATA_CONFIG.SHOW
+
+    return {
+        "config_path": config_path,
+        "input_path": input_path,
+        "output_path": output_path,
+        "model_type": model_type,
+        "model_path": model_path,
+        "device": device,
+        "conf": conf,
+        "iou": iou,
+        "image_size": image_size,
+        "save": save,
+        "show": show,
+    }
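The new `load_config` helper flattens the nested YAML into a single dict, which would let the modelhub classes replace their duplicated attribute-by-attribute `load_config` methods with one call. A sketch of consuming it:

```python
# Sketch: consuming the flat dict returned by
# torchyolo.utils.config_utils.load_config.
from torchyolo.utils.config_utils import load_config

cfg = load_config("torchyolo/default_config.yaml")
print(cfg["model_type"], cfg["model_path"])        # e.g. yolov8 yolov8s.pt
print(cfg["conf"], cfg["iou"], cfg["image_size"])  # 0.25 0.45 640
```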