Skip to content

Commit

Permalink
🚑 hotfix: 修复推理代码的参数设置
Browse files Browse the repository at this point in the history
  • Loading branch information
henryzhuhr committed Jun 8, 2024
1 parent fba3ca2 commit 00ec714
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 41 deletions.
73 changes: 51 additions & 22 deletions infer-track.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,45 @@


class TTrackBbox:
def __init__(self, tlwh: np.ndarray, objid: int, score: np.float32, clsid: int) -> None:
def __init__(
self, tlwh: np.ndarray, objid: int, score: np.float32, clsid: int
) -> None:
self.tlwh = tlwh
self.objid = objid
self.score = score
self.clsid = clsid


class TrackArgs:

@staticmethod
def get_args():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument("-v", "--video", type=str, default=".cache/palace.mp4")
parser.add_argument("-o", "--outdir", type=str, default="tmp")
parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s_openvino_model/yolov5s.xml")
parser.add_argument("-c", "--config", type=str, default=".cache/yolov5/coco.yaml")
parser.add_argument("-s", "--img-size", nargs="+", type=int, default=[640, 640])
parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6,
help="threshold for filtering out boxes of which aspect ratio are above the given value." )
parser.add_argument("--min_box_area", type=float, default=10, help="filter out tiny boxes")
# fmt: on
return parser.parse_args()

def __init__(self) -> None:
args = self.get_args()

self.video: str = args.video
self.output_dir: str = args.outdir
self.model_path: str = args.model
self.config: str = args.config
if len(args.img_size) == 2:
self.img_size: List[int] = args.img_size
elif len(args.img_size) == 1:
self.img_size: List[int] = [args.img_size, args.img_size]
else:
raise ValueError("Invalid img_size")
self.aspect_ratio_thresh: float = args.aspect_ratio_thresh
self.min_box_area: float = args.min_box_area

Expand All @@ -40,19 +65,6 @@ def __init__(self) -> None:
if not os.path.exists(self.model_path):
raise FileNotFoundError(f"Model file {self.model_path} not found")

@staticmethod
def get_args():
parser = argparse.ArgumentParser()
# fmt: off
parser.add_argument("-v", "--video", type=str, default=".cache/palace.mp4")
parser.add_argument("-o", "--outdir", type=str, default="tmp")
parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s_openvino_model/yolov5s.xml")
parser.add_argument("--aspect_ratio_thresh", type=float, default=1.6,
help="threshold for filtering out boxes of which aspect ratio are above the given value." )
parser.add_argument("--min_box_area", type=float, default=10, help="filter out tiny boxes")
# fmt: on
return parser.parse_args()


def main():
args = TrackArgs()
Expand All @@ -72,15 +84,18 @@ def main():
print("-- Available devices:", detector.query_device())
detector.load_model(args.model_path, verbose=True)

with open(".cache/yolov5/yolov5s_openvino_model/yolov5s.yaml", "r") as f:
label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)["names"]
with open(args.config, "r") as f:
label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)[
"names"
]
label_list = list(label_map.values())
print(label_list)

tracker = ByteTracker()

# (1, 3, 640, 640)
dummy_inputs = np.random.randn(1, 3, 640, 640).astype(np.float32)
img_size = args.img_size
dummy_inputs = np.random.randn(1, 3, *img_size).astype(np.float32)
output_t = detector.infer(dummy_inputs)

frame_id = 0
Expand All @@ -99,15 +114,17 @@ def main():
copy_time = get_consume_t_ms(start_time)
start_time = time.time()

input_t, scale_h, scale_w = Process.preprocess(img)
input_t, scale_h, scale_w = Process.preprocess(img, img_size)
preprocess_time = get_consume_t_ms(start_time)
start_time = time.time()

output_t = detector.infer(input_t)
infer_time = get_consume_t_ms(start_time)
start_time = time.time()

preds = Process.postprocess(output_t) # [ B, [x1, y1, x2, y2, conf, cls] ]
preds = Process.postprocess(
output_t
) # [ B, [x1, y1, x2, y2, conf, cls] ]
online_targets = tracker.update(preds, scale_h, scale_w)
online_tackbboxes: List[TTrackBbox] = []
for t in online_targets:
Expand Down Expand Up @@ -181,15 +198,27 @@ def plot_tracking(
line_thickness = 2
radius = max(5, int(img_w / 140.0))

cv2.putText(img, display_info, (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2)
cv2.putText(
img,
display_info,
(0, int(15 * text_scale)),
cv2.FONT_HERSHEY_PLAIN,
2,
(0, 0, 255),
thickness=2,
)

for i, tackbbox in enumerate(tackbboxes):
x1, y1, w, h = tackbbox.tlwh
intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
obj_id = int(tackbbox.objid)
id_text = f"{label_list[tackbbox.clsid]}/{int(obj_id)} ({tackbbox.score:.2f})"
id_text = (
f"{label_list[tackbbox.clsid]}/{int(obj_id)} ({tackbbox.score:.2f})"
)
color = get_color(abs(obj_id))
cv2.rectangle(img, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
cv2.rectangle(
img, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness
)
cv2.putText(
img,
id_text,
Expand Down
52 changes: 33 additions & 19 deletions infer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import argparse
import os
from typing import Dict
from typing import Dict, List
import cv2
import tqdm
import yaml
Expand All @@ -9,18 +9,33 @@
from dlinfer.detector import Process


def parse_args() -> argparse.Namespace:
"""Parse and return command line arguments."""
parser = argparse.ArgumentParser(add_help=False)
args = parser.add_argument_group("Options")
args.add_argument("--model", type=str, default=".cache/yolov5/yolov5s.onnx")
args.add_argument("-i", "--input", type=str, default="images/bus.jpg")
return parser.parse_args()
class InferArgs:
@staticmethod
def get_args():
parser = argparse.ArgumentParser(add_help=False)
# fmt: off
parser.add_argument("-m", "--model", type=str, default=".cache/yolov5/yolov5s.onnx")
parser.add_argument("-c", "--config", type=str, default=".cache/yolov5/coco.yaml")
parser.add_argument("-s", "--img-size", nargs="+", type=int, default=[640, 640])
parser.add_argument("-i", "--input", type=str, default="images/bus.jpg")
# fmt: on
return parser.parse_args()

def __init__(self) -> None:
args = self.get_args()
self.model: str = args.model
self.config: str = args.config
if len(args.img_size) == 2:
self.img_size: List[int] = args.img_size
elif len(args.img_size) == 1:
self.img_size: List[int] = [args.img_size, args.img_size]
else:
raise ValueError("Invalid img_size")
self.input: str = args.input

def main() -> int:
args = parse_args()

def main() -> int:
args = InferArgs()
backends = DetectorInferBackends()
# =============== Choose backend to Infer ===============
# ------ Choose one and comment out the others ------
Expand All @@ -40,29 +55,28 @@ def main() -> int:

detector.load_model(args.model, verbose=True)

with open(".cache/yolov5/coco.yaml", "r") as f:
label_map: Dict[int, str] = yaml.load(f, Loader=yaml.FullLoader)[
"names"
]
label_list = list(label_map.values())
# print(label_list)
with open(args.config, "r") as f:
file_content = yaml.load(f, Loader=yaml.FullLoader)
label_map: Dict[int, str] = file_content["names"]
label_list = list(label_map.values())

img = cv2.imread(args.input) # H W C
os.makedirs("tmp", exist_ok=True)

img_size = args.img_size
# -- warm up
input_t, scale_h, scale_w = Process.preprocess(img) # B C H W
input_t, scale_h, scale_w = Process.preprocess(img, img_size) # B C H W
output_t = detector.infer(input_t)

# -- do inference
print("-- do inference")
pbar = tqdm.tqdm(range(100))
pbar = tqdm.tqdm(range(10))
total_sum_time = 0
_cnt = 0
for i in pbar:
start_time = cv2.getTickCount()
# -- preprocess
input_t, scale_h, scale_w = Process.preprocess(img) # B C H W
input_t, scale_h, scale_w = Process.preprocess(img, img_size) # B C H W
# -- inference
output_t = detector.infer(input_t)
end_time = cv2.getTickCount()
Expand Down

0 comments on commit 00ec714

Please sign in to comment.