Skip to content

Commit

Permalink
refactored code
Browse files Browse the repository at this point in the history
  • Loading branch information
OneMagicKey committed May 28, 2024
1 parent d081ff0 commit 635fb36
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 26 deletions.
1 change: 0 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
!src/utils.py

# Install files
!src/images/bus.jpg
!requirements.txt

**/__pycache__
44 changes: 25 additions & 19 deletions src/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,24 +109,25 @@ def print_results(

return result_list

def version_handler(self, output: np.ndarray, nc: int) -> tuple[np.ndarray, int]:
def version_handler(self, output: np.ndarray, nc: int) -> np.ndarray:
"""
Convert output of the model to a single format.
:param output: an array with boxes from the model's output
:param nc: number of classes
:return: output with shape (num_boxes, probs_start_idx+num_classes+num_masks) and probs_start_idx
:return: output with shape (num_objects, 4+num_classes+num_masks)
"""
if self.version == 8:
probs_start_idx = 4
output = output.transpose((0, 2, 1)).squeeze(0)
match self.version:
case 8:
output = output.transpose((0, 2, 1)).squeeze(0)
case 5:
output = output[output[..., 4] > self.conf]
output[..., 5 : 5 + nc] *= output[..., 4:5] # conf = obj_conf * cls_conf # fmt: skip
output = np.delete(output, 4, axis=1) # remove obj_conf to preserve the format # fmt: skip
case _:
output = output.squeeze(0)

else: # v5
probs_start_idx = 5
output = output[output[..., 4] > self.conf]
output[..., probs_start_idx : probs_start_idx + nc] *= output[..., 4:5] # conf = obj_conf * cls_conf # fmt: skip

return output, probs_start_idx
return output

def get_colors(
self, classes: np.ndarray, color_scheme: Literal["equal", "random"] = "equal"
Expand Down Expand Up @@ -155,26 +156,31 @@ def nms(
return_masks: bool = False,
) -> tuple[np.ndarray, ...]:
"""
Apply non-maximum suppression to the output of the model.
Post-processes the output from a YOLOv5/YOLOv8 model to extract classes,
confidences, bounding boxes, and optionally, mask coefficients.
This function processes the raw output from the model, filters detections using
non-maximum suppression, and adjusts the bounding box coordinates based on the
provided ratio and padding values.
:return: classes, confidences, boxes and (optional) mask coefficients
"""
nc = len(self.labels_name["en"])
output, probs_id = self.version_handler(output, nc)
output = self.version_handler(output, nc)

classes = output[..., probs_id : probs_id + nc].argmax(axis=-1)
boxes = xywh2box(output[..., :4], ratio, padw=pad[0], padh=pad[1])
confs = output[..., probs_id : probs_id + nc]
boxes, confs, masks_coefs = np.split(output, [4, 4 + nc], axis=1)
classes = confs.argmax(axis=-1)
confs = confs[np.arange(classes.shape[-1]), classes]
boxes = xywh2box(boxes, ratio, padw=pad[0], padh=pad[1])

indices = cv2.dnn.NMSBoxes(boxes, confs, self.conf, self.iou)
indices = list(indices)
classes, confs, boxes = classes[indices], confs[indices], boxes[indices]

if return_masks:
masks = output[indices, probs_id + nc :]
return classes[indices], confs[indices], xywh2xyxy(boxes[indices]), masks
return classes, confs, xywh2xyxy(boxes), masks_coefs[indices]

return classes[indices], confs[indices], xywh2xyxy(boxes[indices])
return classes, confs, xywh2xyxy(boxes)

@abstractmethod
def postprocess(self, *args, **kwargs) -> tuple[np.ndarray, ...]:
Expand Down
4 changes: 2 additions & 2 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass
from typing import Literal

import cv2
import numpy as np
from aiogram import types

from model.model import YoloOnnxDetection, YoloOnnxSegmentation
Expand Down Expand Up @@ -43,7 +43,7 @@ def init_models(
}

# Warmup
test_img = cv2.imread("src/images/bus.jpg", cv2.IMREAD_COLOR)
test_img = np.ones((640, 640, 3), dtype=np.uint8)
for model in bot_models.values():
_ = model(test_img, raw=True)

Expand Down
7 changes: 3 additions & 4 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def load_img(path: str = "src/images/zidane.jpg"):

@staticmethod
def create_empty_img():
img = np.zeros((512, 512, 3), dtype=np.uint8)
img.fill(255)
img = np.ones((512, 512, 3), dtype=np.uint8)

return img

Expand Down Expand Up @@ -91,7 +90,7 @@ def test_render_detection(self):

def test_render_empty_segmentation(self):
img, yolo = self.img_empty, self.segmentation_model
path_save_to = "tests/white_seg.jpg"
path_save_to = "tests/empty_seg.jpg"

classes, confs, boxes, masks = yolo(img)
yolo.render(img, classes, confs, boxes, masks, save_path=path_save_to)
Expand All @@ -102,7 +101,7 @@ def test_render_empty_segmentation(self):

def test_render_empty_detection(self):
img, yolo = self.img_empty, self.detection_model
path_save_to = "tests/white_det.jpg"
path_save_to = "tests/empty_det.jpg"

classes, confs, boxes = yolo(img)
yolo.render(img, classes, confs, boxes, save_path=path_save_to)
Expand Down

0 comments on commit 635fb36

Please sign in to comment.