From db4c43bafbed1e1298c4436b8c7e7c6e5b9f3eec Mon Sep 17 00:00:00 2001 From: Francesco Mattioli Date: Wed, 31 Jul 2024 15:13:27 +0200 Subject: [PATCH] Exported model batch size validation fix (#14845) Co-authored-by: Glenn Jocher --- ultralytics/engine/validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ultralytics/engine/validator.py b/ultralytics/engine/validator.py index 8a2765c98f3..4a40a88291f 100644 --- a/ultralytics/engine/validator.py +++ b/ultralytics/engine/validator.py @@ -136,8 +136,8 @@ def __call__(self, trainer=None, model=None): if engine: self.args.batch = model.batch_size elif not pt and not jit: - self.args.batch = 1 # export.py models default to batch-size 1 - LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models") + self.args.batch = model.metadata.get("batch", 1) # export.py models default to batch-size 1 + LOGGER.info(f"Setting batch={self.args.batch} input of shape ({self.args.batch}, 3, {imgsz}, {imgsz})") if str(self.args.data).split(".")[-1] in {"yaml", "yml"}: self.data = check_det_dataset(self.args.data)