From b608d80db1b5f3ecaf0ed63615264732b1006a58 Mon Sep 17 00:00:00 2001 From: kshitijrajsharma Date: Mon, 25 Nov 2024 16:22:58 +0000 Subject: [PATCH] fix(onnx): supports onnx output of the yolo models for the inference --- hot_fair_utilities/training/yolo_v8_v1/train.py | 4 +++- hot_fair_utilities/training/yolo_v8_v2/train.py | 4 +++- hot_fair_utilities/utils.py | 10 +++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/hot_fair_utilities/training/yolo_v8_v1/train.py b/hot_fair_utilities/training/yolo_v8_v1/train.py index 71f8965..ff5ea6f 100644 --- a/hot_fair_utilities/training/yolo_v8_v1/train.py +++ b/hot_fair_utilities/training/yolo_v8_v1/train.py @@ -7,9 +7,10 @@ import torch import ultralytics + # Reader imports from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight -from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics +from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics,export_model_to_onnx # Get environment variables with fallbacks # ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute())) # DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training"))) @@ -128,6 +129,7 @@ def train( output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt") iou_model_accuracy=get_yolo_iou_metrics(output_model_path) + export_model_to_onnx(output_model_path) return output_model_path,iou_model_accuracy diff --git a/hot_fair_utilities/training/yolo_v8_v2/train.py b/hot_fair_utilities/training/yolo_v8_v2/train.py index 8297ca7..80b960a 100644 --- a/hot_fair_utilities/training/yolo_v8_v2/train.py +++ b/hot_fair_utilities/training/yolo_v8_v2/train.py @@ -6,7 +6,7 @@ # Third party imports import torch import ultralytics -from ...utils import get_yolo_iou_metrics,compute_iou_chart_from_yolo_results +from ...utils import get_yolo_iou_metrics,compute_iou_chart_from_yolo_results,export_model_to_onnx # Reader imports from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight @@ -73,6 +73,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path, weights, resume = check4checkpoint(name, weights,output_path) model = yolo(weights) + model.train( data=data_scn, project=os.path.join(output_path,"checkpoints"), # Using the environment variable with fallback @@ -93,6 +94,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path, output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt") iou_model_accuracy=get_yolo_iou_metrics(output_model_path) + export_model_to_onnx(output_model_path) return output_model_path,iou_model_accuracy diff --git a/hot_fair_utilities/utils.py b/hot_fair_utilities/utils.py index 1ce759d..f9c86a7 100644 --- a/hot_fair_utilities/utils.py +++ b/hot_fair_utilities/utils.py @@ -287,4 +287,12 @@ def get_yolo_iou_metrics(model_path): - 1 ) # ref here https://github.com/ultralytics/ultralytics/issues/9984#issuecomment-2422551315 final_accuracy = iou_accuracy * 100 - return final_accuracy \ No newline at end of file + return final_accuracy + + + +def export_model_to_onnx(model_path): + model = ultralytics.YOLO(model_path) + model.export(format='onnx',imgsz=[256,256]) + # model.export(format='tflite') + return True \ No newline at end of file