Skip to content

Commit

Permalink
fix(onnx): supports onnx output of the yolo models for the inference
Browse files Browse the repository at this point in the history
  • Loading branch information
kshitijrajsharma committed Nov 25, 2024
1 parent 6fdaa18 commit b608d80
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
4 changes: 3 additions & 1 deletion hot_fair_utilities/training/yolo_v8_v1/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
import torch
import ultralytics


# Reader imports
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight
from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics
from ...utils import compute_iou_chart_from_yolo_results, get_yolo_iou_metrics,export_model_to_onnx
# Get environment variables with fallbacks
# ROOT = Path(os.getenv("YOLO_ROOT", Path(__file__).parent.absolute()))
# DATA_ROOT = str(Path(os.getenv("YOLO_DATA_ROOT", ROOT / "yolo-training")))
Expand Down Expand Up @@ -128,6 +129,7 @@ def train(
output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt")

iou_model_accuracy=get_yolo_iou_metrics(output_model_path)
export_model_to_onnx(output_model_path)

return output_model_path,iou_model_accuracy

Expand Down
4 changes: 3 additions & 1 deletion hot_fair_utilities/training/yolo_v8_v2/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# Third party imports
import torch
import ultralytics
from ...utils import get_yolo_iou_metrics,compute_iou_chart_from_yolo_results
from ...utils import get_yolo_iou_metrics,compute_iou_chart_from_yolo_results,export_model_to_onnx
# Reader imports
from hot_fair_utilities.model.yolo import YOLOSegWithPosWeight

Expand Down Expand Up @@ -73,6 +73,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path,

weights, resume = check4checkpoint(name, weights,output_path)
model = yolo(weights)

model.train(
data=data_scn,
project=os.path.join(output_path,"checkpoints"), # Using the environment variable with fallback
Expand All @@ -93,6 +94,7 @@ def train(data, weights, epochs, batch_size, pc, output_path, dataset_yaml_path,
output_model_path=os.path.join(os.path.join(output_path,"checkpoints"), name, "weights", "best.pt")

iou_model_accuracy=get_yolo_iou_metrics(output_model_path)
export_model_to_onnx(output_model_path)

return output_model_path,iou_model_accuracy

Expand Down
10 changes: 9 additions & 1 deletion hot_fair_utilities/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,4 +287,12 @@ def get_yolo_iou_metrics(model_path):
- 1
) # ref here https://github.com/ultralytics/ultralytics/issues/9984#issuecomment-2422551315
final_accuracy = iou_accuracy * 100
return final_accuracy
return final_accuracy



def export_model_to_onnx(model_path):
model = ultralytics.YOLO(model_path)
model.export(format='onnx',imgsz=[256,256])
# model.export(format='tflite')
return True

0 comments on commit b608d80

Please sign in to comment.