Skip to content

Commit

Permalink
obj det
Browse files Browse the repository at this point in the history
  • Loading branch information
TingquanGao committed Nov 20, 2024
1 parent b1defcf commit 777a143
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 52 deletions.
2 changes: 1 addition & 1 deletion paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Global:
model: DLinear_ad
mode: predict # check_dataset/train/evaluate/predict
mode: check_dataset # check_dataset/train/evaluate/predict
dataset_dir: "/paddle/dataset/paddlex/ts_ad/ts_anomaly_examples/"
device: gpu:0
output: "output"
Expand Down
4 changes: 3 additions & 1 deletion paddlex/inference/models/base/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def _check(self):
cmpt.set_outputs(outputs)
if idx != 0:
for input_key in inputs:
assert input_key in simulation_data
assert (
input_key in simulation_data
), f"{input_key} is needed by {name}, but not found!"
simulation_data.extend(list(outputs.values()))

def __call__(self, data, i=0):
Expand Down
12 changes: 7 additions & 5 deletions paddlex/inference/models/object_detection/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ Components:
ImageDetPredictor:
inputs:
img: img
scale_factors: scale_factors
outputs:
pred: pred
DetPostProcess:
inputs:
boxes: boxes
img_size: img_size
outputs:
boxes: boxes
DetPostProcess:
inputs:
boxes: boxes
img_size: img_size
outputs:
boxes: boxes
91 changes: 46 additions & 45 deletions paddlex/inference/models/object_detection/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,58 +19,59 @@
from ....utils.fonts import PINGFANG_FONT_FILE_PATH
from ...utils.color_map import get_colormap, font_colormap
from ..base import CVResult
from ...common.funcs import draw_box


def draw_box(img, boxes):
"""
Args:
img (PIL.Image.Image): PIL image
boxes (list): a list of dictionaries representing detection box information.
Returns:
img (PIL.Image.Image): visualized image
"""
font_size = int(0.024 * int(img.width)) + 2
font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
# def draw_box(img, boxes):
# """
# Args:
# img (PIL.Image.Image): PIL image
# boxes (list): a list of dictionaries representing detection box information.
# Returns:
# img (PIL.Image.Image): visualized image
# """
# font_size = int(0.024 * int(img.width)) + 2
# font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")

draw_thickness = int(max(img.size) * 0.005)
draw = ImageDraw.Draw(img)
label2color = {}
catid2fontcolor = {}
color_list = get_colormap(rgb=True)
# draw_thickness = int(max(img.size) * 0.005)
# draw = ImageDraw.Draw(img)
# label2color = {}
# catid2fontcolor = {}
# color_list = get_colormap(rgb=True)

for i, dt in enumerate(boxes):
# clsid = dt["cls_id"]
label, bbox, score = dt["label"], dt["coordinate"], dt["score"]
if label not in label2color:
color_index = i % len(color_list)
label2color[label] = color_list[color_index]
catid2fontcolor[label] = font_colormap(color_index)
color = tuple(label2color[label])
font_color = tuple(catid2fontcolor[label])
# for i, dt in enumerate(boxes):
# # clsid = dt["cls_id"]
# label, bbox, score = dt["label"], dt["coordinate"], dt["score"]
# if label not in label2color:
# color_index = i % len(color_list)
# label2color[label] = color_list[color_index]
# catid2fontcolor[label] = font_colormap(color_index)
# color = tuple(label2color[label])
# font_color = tuple(catid2fontcolor[label])

xmin, ymin, xmax, ymax = bbox
# draw bbox
draw.line(
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
width=draw_thickness,
fill=color,
)
# xmin, ymin, xmax, ymax = bbox
# # draw bbox
# draw.line(
# [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin), (xmin, ymin)],
# width=draw_thickness,
# fill=color,
# )

# draw label
text = "{} {:.2f}".format(dt["label"], score)
if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
tw, th = draw.textsize(text, font=font)
else:
left, top, right, bottom = draw.textbbox((0, 0), text, font)
tw, th = right - left, bottom - top + 4
if ymin < th:
draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
else:
draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)
# # draw label
# text = "{} {:.2f}".format(dt["label"], score)
# if tuple(map(int, PIL.__version__.split("."))) <= (10, 0, 0):
# tw, th = draw.textsize(text, font=font)
# else:
# left, top, right, bottom = draw.textbbox((0, 0), text, font)
# tw, th = right - left, bottom - top + 4
# if ymin < th:
# draw.rectangle([(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)
# draw.text((xmin + 2, ymin - 2), text, fill=font_color, font=font)
# else:
# draw.rectangle([(xmin, ymin - th), (xmin + tw + 4, ymin + 1)], fill=color)
# draw.text((xmin + 2, ymin - th - 2), text, fill=font_color, font=font)

return img
# return img


class DetResult(CVResult):
Expand Down

0 comments on commit 777a143

Please sign in to comment.