From e0c54be2bd77d1bbc08531edcbbeccdc4cfa0e0e Mon Sep 17 00:00:00 2001 From: Eugene Liu Date: Wed, 3 Apr 2024 12:01:22 +0100 Subject: [PATCH] add fix --- src/otx/core/model/entity/base.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/otx/core/model/entity/base.py b/src/otx/core/model/entity/base.py index fc7dd604d22..cf506e3170c 100644 --- a/src/otx/core/model/entity/base.py +++ b/src/otx/core/model/entity/base.py @@ -15,6 +15,7 @@ import openvino from jsonargparse import ArgumentParser from openvino.model_api.models import Model +from openvino.model_api.tilers import Tiler from torch import nn from otx.core.data.dataset.base import LabelInfo @@ -463,8 +464,10 @@ def transform_fn(self, data_batch: T_OTXBatchDataEntity) -> np.array: """Data transform function for PTQ.""" np_data = self._customize_inputs(data_batch) image = np_data["inputs"][0] - resized_image = self.model.resize(image, (self.model.w, self.model.h)) - resized_image = self.model.input_transform(resized_image) + # NOTE: Tiler wraps the model, so we need to unwrap it to get the model + model = self.model.model if isinstance(self.model, Tiler) else self.model + resized_image = model.resize(image, (model.w, model.h)) + resized_image = model.input_transform(resized_image) return self.model._change_layout(resized_image) # noqa: SLF001 def _read_ptq_config_from_ir(self, ov_model: Model) -> dict[str, Any]: