Skip to content

Commit

Permalink
Merge pull request #21 from kadirnar/yolox
Browse files Browse the repository at this point in the history
Updated the Yolox visualization module.
  • Loading branch information
kadirnar authored Jan 6, 2023
2 parents 5ca82ab + 820857f commit ef8ab53
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 7 deletions.
2 changes: 1 addition & 1 deletion torchyolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from torchyolo.predict import YoloPredictor

__version__ = "0.1.4"
__version__ = "0.1.5"
44 changes: 43 additions & 1 deletion torchyolo/modelhub/yolox.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import cv2
from sahi.prediction import ObjectPrediction, PredictionResult
from sahi.utils.cv import visualize_object_predictions
from yoloxdetect import YoloxDetector

from torchyolo.modelhub.basemodel import YoloDetectionModel
Expand All @@ -6,11 +9,50 @@
class YoloxDetectionModel(YoloDetectionModel):
def load_model(self):
model = YoloxDetector(self.model_path, config_path=self.config_path, device=self.device, hf_model=False)
model.torchyolo = True
model.conf = self.confidence_threshold
model.iou = self.iou_threshold
model.save = self.save
model.show = self.show
self.model = model

def predict(self, image, yaml_file=None):
self.model.predict(image_path=image, image_size=self.image_size)
object_prediction_list = []
predict_list = self.model.predict(image_path=image, image_size=self.image_size)
boxes, scores, cls_ids, class_names = predict_list[0], predict_list[1], predict_list[2], predict_list[3]
for i in range(len(boxes)):
box = boxes[i]
category_id = int(cls_ids[i])
score = scores[i]
if score < self.confidence_threshold:
continue
x0 = int(box[0])
y0 = int(box[1])
x1 = int(box[2])
y1 = int(box[3])
bbox = [x0, y0, x1, y1]
category_name = class_names[category_id]

object_prediction = ObjectPrediction(
bbox=bbox,
score=score,
category_id=category_id,
category_name=category_name,
)
object_prediction_list.append(object_prediction)

prediction_result = PredictionResult(
object_prediction_list=object_prediction_list,
image=image,
)
if self.save:
prediction_result.export_visuals(export_dir=self.save_path, file_name=self.output_file_name)

if self.show:
image = cv2.imread(image)
output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
cv2.imshow("Prediction", output_image["image"])
cv2.waitKey(0)
cv2.destroyAllWindows()

return prediction_result
22 changes: 17 additions & 5 deletions torchyolo/predict.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
from typing import Optional

from torchyolo.automodel import AutoDetectionModel


class YoloPredictor:
def __init__(self, model_type="yolov5", model_path="yolov5s.pt", device="cpu", image_size=640):
def __init__(
self,
model_type: str = "yolov5",
model_path: str = "yolov5s.pt",
device: str = "cpu",
image_size: int = 640,
config_path: Optional[str] = "configs.yolox.yolox_s",
):
self.model_type = model_type
self.model_path = model_path
self.config_path = "torchyolo.configs.yolox.yolox_m" # yolox_nano.py
self.config_path = config_path
self.device = device
self.conf = 0.05
self.iou = 0.05
self.conf = 0.45
self.iou = 0.45
self.image_size = image_size
self.save = True
self.show = False
Expand All @@ -35,6 +44,9 @@ def predict(self, image, yaml_file=None):


if __name__ == "__main__":
predictor = YoloPredictor(model_type="yolov5", model_path="yolov5n.pt", device="cuda:0", image_size=640)
predictor = YoloPredictor(
model_type="yolox", model_path="yolox_s.pth", config_path="configs.yolox.yolox_s", device="cpu", image_size=640
)

image = "data/highway.jpg"
result = predictor.predict(image)

0 comments on commit ef8ab53

Please sign in to comment.