From 00ec7143dbe3744f32107f126f2da4dff40229f6 Mon Sep 17 00:00:00 2001 From: HenryZhuHR <296506195@qq.com> Date: Sat, 8 Jun 2024 22:01:57 +0800 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=20hotfix:=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E6=8E=A8=E7=90=86=E4=BB=A3=E7=A0=81=E7=9A=84=E5=8F=82=E6=95=B0?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- infer-track.py | 73 +++++++++++++++++++++++++++++++++++--------------- infer.py | 52 ++++++++++++++++++++++------------- 2 files changed, 84 insertions(+), 41 deletions(-) diff --git a/infer-track.py b/infer-track.py index 12b61ef..5489ab0 100644 --- a/infer-track.py +++ b/infer-track.py @@ -14,7 +14,9 @@ class TTrackBbox: - def __init__(self, tlwh: np.ndarray, objid: int, score: np.float32, clsid: int) -> None: + def __init__( + self, tlwh: np.ndarray, objid: int, score: np.float32, clsid: int + ) -> None: self.tlwh = tlwh self.objid = objid self.score = score @@ -22,12 +24,35 @@ def __init__(self, tlwh: np.ndarray, objid: int, score: np.float32, clsid: int) class TrackArgs: + + @staticmethod + def get_args(): + parser = argparse.ArgumentParser() + # fmt: off + parser.add_argument("-v", "--video", type=str, default=".cache/palace.mp4") + parser.add_argument("-o", "--outdir", type=str, default="tmp") + parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s_openvino_model/yolov5s.xml") + parser.add_argument("-c", "--config", type=str, default=".cache/yolov5/coco.yaml") + parser.add_argument("-s", "--img-size", nargs="+", type=int, default=[640, 640]) + parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6, + help="threshold for filtering out boxes of which aspect ratio are above the given value." ) + parser.add_argument("--min_box_area", type=float, default=10, help="filter out tiny boxes") + # fmt: on + return parser.parse_args() + def __init__(self) -> None: args = self.get_args() self.video: str = args.video self.output_dir: str = args.outdir self.model_path: str = args.model + self.config: str = args.config + if len(args.img_size) == 2: + self.img_size: List[int] = args.img_size + elif len(args.img_size) == 1: + self.img_size: List[int] = [args.img_size, args.img_size] + else: + raise ValueError("Invalid img_size") self.aspect_ratio_thresh: float = args.aspect_ratio_thresh self.min_box_area: float = args.min_box_area @@ -40,19 +65,6 @@ def __init__(self) -> None: if not os.path.exists(self.model_path): raise FileNotFoundError(f"Model file {self.model_path} not found") - @staticmethod - def get_args(): - parser = argparse.ArgumentParser() - # fmt: off - parser.add_argument("-v", "--video", type=str, default=".cache/palace.mp4") - parser.add_argument("-o", "--outdir", type=str, default="tmp") - parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s_openvino_model/yolov5s.xml") - parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6, - help="threshold for filtering out boxes of which aspect ratio are above the given value." ) - parser.add_argument("--min_box_area", type=float, default=10, help="filter out tiny boxes") - # fmt: on - return parser.parse_args() - def main(): args = TrackArgs() @@ -72,15 +84,18 @@ def main(): print("-- Available devices:", detector.query_device()) detector.load_model(args.model_path, verbose=True) - with open(".cache/yolov5/yolov5s_openvino_model/yolov5s.yaml", "r") as f: - label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)["names"] + with open(args.config, "r") as f: + label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)[ + "names" + ] label_list = list(label_map.values()) print(label_list) tracker = ByteTracker() # (1, 3, 640, 640) - dummy_inputs = np.random.randn(1, 3, 640, 640).astype(np.float32) + img_size = args.img_size + dummy_inputs = np.random.randn(1, 3, *img_size).astype(np.float32) output_t = detector.infer(dummy_inputs) frame_id = 0 @@ -99,7 +114,7 @@ def main(): copy_time = get_consume_t_ms(start_time) start_time = time.time() - input_t, scale_h, scale_w = Process.preprocess(img) + input_t, scale_h, scale_w = Process.preprocess(img, img_size) preprocess_time = get_consume_t_ms(start_time) start_time = time.time() @@ -107,7 +122,9 @@ def main(): infer_time = get_consume_t_ms(start_time) start_time = time.time() - preds = Process.postprocess(output_t) # [ B, [x1, y1, x2, y2, conf, cls] ] + preds = Process.postprocess( + output_t + ) # [ B, [x1, y1, x2, y2, conf, cls] ] online_targets = tracker.update(preds, scale_h, scale_w) online_tackbboxes: List[TTrackBbox] = [] for t in online_targets: @@ -181,15 +198,27 @@ def plot_tracking( line_thickness = 2 radius = max(5, int(img_w / 140.0)) - cv2.putText(img, display_info, (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2) + cv2.putText( + img, + display_info, + (0, int(15 * text_scale)), + cv2.FONT_HERSHEY_PLAIN, + 2, + (0, 0, 255), + thickness=2, + ) for i, tackbbox in enumerate(tackbboxes): x1, y1, w, h = tackbbox.tlwh intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h))) obj_id = int(tackbbox.objid) - id_text = f"{label_list[tackbbox.clsid]}/{int(obj_id)} ({tackbbox.score:.2f})" + id_text = ( + f"{label_list[tackbbox.clsid]}/{int(obj_id)} ({tackbbox.score:.2f})" + ) color = get_color(abs(obj_id)) - cv2.rectangle(img, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness) + cv2.rectangle( + img, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness + ) cv2.putText( img, id_text, diff --git a/infer.py b/infer.py index a1d0715..21137fe 100644 --- a/infer.py +++ b/infer.py @@ -1,6 +1,6 @@ import argparse import os -from typing import Dict +from typing import Dict, List import cv2 import tqdm import yaml @@ -9,18 +9,33 @@ from dlinfer.detector import Process -def parse_args() -> argparse.Namespace: - """Parse and return command line arguments.""" - parser = argparse.ArgumentParser(add_help=False) - args = parser.add_argument_group("Options") - args.add_argument("--model", type=str, default=".cache/yolov5/yolov5s.onnx") - args.add_argument("-i", "--input", type=str, default="images/bus.jpg") - return parser.parse_args() +class InferArgs: + @staticmethod + def get_args(): + parser = argparse.ArgumentParser(add_help=False) + # fmt: off + parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s.onnx") + parser.add_argument("-c", "--config", type=str, default=".cache/yolov5/coco.yaml") + parser.add_argument("-s", "--img-size", nargs="+", type=int, default=[640, 640]) + parser.add_argument("-i", "--input", type=str, default="images/bus.jpg") + # fmt: on + return parser.parse_args() + def __init__(self) -> None: + args = self.get_args() + self.model: str = args.model + self.config: str = args.config + if len(args.img_size) == 2: + self.img_size: List[int] = args.img_size + elif len(args.img_size) == 1: + self.img_size: List[int] = [args.img_size, args.img_size] + else: + raise ValueError("Invalid img_size") + self.input: str = args.input -def main() -> int: - args = parse_args() +def main() -> int: + args = InferArgs() backends = DetectorInferBackends() # =============== Choose backend to Infer =============== # ------ Choose one and comment out the others ------ @@ -40,29 +55,28 @@ def main() -> int: detector.load_model(args.model, verbose=True) - with open(".cache/yolov5/coco.yaml", "r") as f: - label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)[ - "names" - ] - label_list = list(label_map.values()) - # print(label_list) + with open(args.config, "r") as f: + file_content = yaml.load(f, Loader=yaml.FullLoader) + label_map: Dict[int, str] = file_content["names"] + label_list = list(label_map.values()) img = cv2.imread(args.input) # H W C os.makedirs("tmp", exist_ok=True) + img_size = args.img_size # -- warm up - input_t, scale_h, scale_w = Process.preprocess(img) # B C H W + input_t, scale_h, scale_w = Process.preprocess(img, img_size) # B C H W output_t = detector.infer(input_t) # -- do inference print("-- do inference") - pbar = tqdm.tqdm(range(100)) + pbar = tqdm.tqdm(range(10)) total_sum_time = 0 _cnt = 0 for i in pbar: start_time = cv2.getTickCount() # -- preprocess - input_t, scale_h, scale_w = Process.preprocess(img) # B C H W + input_t, scale_h, scale_w = Process.preprocess(img, img_size) # B C H W # -- inference output_t = detector.infer(input_t) end_time = cv2.getTickCount()