Skip to content

Commit

Permalink
Merge pull request #19 from kadirnar/yolov6
Browse files Browse the repository at this point in the history
Updated the Yollov6 visualization module.
  • Loading branch information
kadirnar authored Jan 6, 2023
2 parents 7677a6d + 84730ce commit 5ca82ab
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 22 deletions.
8 changes: 2 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ predictor.save = True
predictor.show = False
image = "data/highway.jpg"
result = predictor.predict(image)
# Yolov6
result = predictor.predict(image, class_names="coco.names")
```
Note: You only need to make changes in the default_config.yaml file.

# Contributing
Before opening a PR:
Expand All @@ -41,11 +42,6 @@ Before opening a PR:
bash script/code_format.sh
```

### TODO
- [ ] Add more models(YOLOV4, Scaled-YOLOv4, YOLOR)
- [ ] Add Train, Export and Eval scripts
- [ ] Add Benchmark Results

### Acknowledgement
A part of the code is borrowed from [SAHI](https://github.com/obss/sahi). Many thanks for their wonderful works.

Expand Down
2 changes: 1 addition & 1 deletion torchyolo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from torchyolo.predict import YoloPredictor

__version__ = "0.1.3"
__version__ = "0.1.4"
2 changes: 1 addition & 1 deletion torchyolo/modelhub/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_model(self):
"""
raise NotImplementedError()

def predict(self, image: np.ndarray):
def predict(self, image: np.ndarray, yaml_file: str = None):
"""
This function should be implemented in a way that detection model
should be initialized and set to self.model.
Expand Down
2 changes: 1 addition & 1 deletion torchyolo/modelhub/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def load_model(self):
model.iou = self.iou_threshold
self.model = model

def predict(self, image):
def predict(self, image, yaml_file=None):
prediction = self.model(image, size=self.image_size)
object_prediction_list = []
for _, image_predictions_in_xyxy_format in enumerate(prediction.xyxy):
Expand Down
48 changes: 44 additions & 4 deletions torchyolo/modelhub/yolov6.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import cv2
from sahi.prediction import ObjectPrediction, PredictionResult
from sahi.utils.cv import visualize_object_predictions
from yolov6 import YOLOV6

from torchyolo.modelhub.basemodel import YoloDetectionModel
Expand All @@ -6,12 +9,49 @@
class Yolov6DetectionModel(YoloDetectionModel):
def load_model(self):
model = YOLOV6(self.model_path, device=self.device)
model.torchyolo = True
model.font_path = "torchyolo/configs/yolov6/Arial.ttf"
model.conf_thres = self.confidence_threshold
model.iou_thresh = self.iou_threshold
model.conf = self.confidence_threshold
model.iou = self.iou_threshold
model.save_img = self.save
model.show_img = self.show
self.model = model

def predict(self, image):
self.model.predict(source=image, img_size=self.image_size, yaml=self.yaml_file)
def predict(self, image, yaml_file="torchyolo/configs/yolov6/coco.yaml"):
object_prediction_list = []
predictions, class_names = self.model.predict(source=image, img_size=self.image_size, yaml=yaml_file)
for *xyxy, conf, cls in reversed(predictions.cpu().detach().numpy()):
x1, y1, x2, y2 = (
int(xyxy[0]),
int(xyxy[1]),
int(xyxy[2]),
int(xyxy[3]),
)
bbox = [x1, y1, x2, y2]
score = conf
category_id = int(cls)
category_name = class_names[category_id]

object_prediction = ObjectPrediction(
bbox=bbox,
category_id=int(category_id),
score=score,
category_name=category_name,
)
object_prediction_list.append(object_prediction)

prediction_result = PredictionResult(
object_prediction_list=object_prediction_list,
image=image,
)
if self.save:
prediction_result.export_visuals(export_dir=self.save_path, file_name=self.output_file_name)

if self.show:
image = cv2.imread(image)
output_image = visualize_object_predictions(image=image, object_prediction_list=object_prediction_list)
cv2.imshow("Prediction", output_image["image"])
cv2.waitKey(0)
cv2.destroyAllWindows()

return prediction_result
2 changes: 1 addition & 1 deletion torchyolo/modelhub/yolov7.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def load_model(self):
model.iou = self.iou_threshold
self.model = model

def predict(self, image):
def predict(self, image, yaml_file=None):
prediction = self.model(image, size=self.image_size)
object_prediction_list = []
for _, image_predictions_in_xyxy_format in enumerate(prediction.xyxy):
Expand Down
2 changes: 1 addition & 1 deletion torchyolo/modelhub/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ def load_model(self):
model.show = self.show
self.model = model

def predict(self, image):
def predict(self, image, yaml_file=None):
self.model.predict(image_path=image, image_size=self.image_size)
14 changes: 7 additions & 7 deletions torchyolo/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def __init__(self, model_type="yolov5", model_path="yolov5s.pt", device="cpu", i
self.model_path = model_path
self.config_path = "torchyolo.configs.yolox.yolox_m" # yolox_nano.py
self.device = device
self.conf_thres = 0.25
self.iou_thres = 0.45
self.conf = 0.05
self.iou = 0.05
self.image_size = image_size
self.save = True
self.show = False
Expand All @@ -22,19 +22,19 @@ def load_model(self):
model_path=self.model_path,
config_path=self.config_path,
device=self.device,
confidence_threshold=self.conf_thres,
iou_threshold=self.iou_thres,
confidence_threshold=self.conf,
iou_threshold=self.iou,
image_size=self.image_size,
)
model.save = self.save
model.show = self.show
self.model = model

def predict(self, image):
return self.model.predict(image)
def predict(self, image, yaml_file=None):
return self.model.predict(image, yaml_file=yaml_file)


if __name__ == "__main__":
predictor = YoloPredictor(model_type="yolov5", model_path="yolov5s.pt", device="cpu", image_size=640)
predictor = YoloPredictor(model_type="yolov5", model_path="yolov5n.pt", device="cuda:0", image_size=640)
image = "data/highway.jpg"
result = predictor.predict(image)

0 comments on commit 5ca82ab

Please sign in to comment.