Skip to content

Commit

Permalink
Fix tile CLI on releases/2.0.0 (#3229)
Browse files Browse the repository at this point in the history
* update tile cli

* add test

* update test

* Update src/otx/engine/engine.py

Co-authored-by: Harim Kang <[email protected]>

* update engine/unittest accodring to comments

* remove accidentally added files

* remove unused imports

---------

Co-authored-by: Harim Kang <[email protected]>
  • Loading branch information
eugene123tw and harimkang authored Mar 29, 2024
1 parent b6437d2 commit dbac83f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 6 deletions.
4 changes: 2 additions & 2 deletions src/otx/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def instantiate_model(self, model_config: Namespace) -> tuple:
Returns:
tuple: The model and optimizer and scheduler.
"""
from otx.core.model.entity.base import OTXModel
from otx.core.model.entity.base import OTXModel, OVModel

# Update num_classes
if not self.get_config_value(self.config_init, "disable_infer_num_classes", False):
Expand Down Expand Up @@ -437,7 +437,7 @@ def instantiate_model(self, model_config: Namespace) -> tuple:
self.config_init[self.subcommand]["model"] = model

# Update tile config due to adaptive tiling
if self.datamodule.config.tile_config.enable_tiler:
if not isinstance(model, OVModel) and self.datamodule.config.tile_config.enable_tiler:
if not hasattr(model, "tile_config"):
msg = "The model does not have a tile_config attribute. Please check if the model supports tiling."
raise AttributeError(msg)
Expand Down
10 changes: 8 additions & 2 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,12 @@ def test(

is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

# NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
if isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")

metric = metric if metric is not None else self._auto_configurator.get_metric()
lit_module = self._build_lightning_module(
model=model,
Expand Down Expand Up @@ -387,9 +390,12 @@ def predict(

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

# NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
if isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")

lit_module = self._build_lightning_module(
model=model,
optimizer=self.optimizer,
Expand Down
7 changes: 5 additions & 2 deletions src/otx/engine/utils/auto_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,12 +347,15 @@ def update_ov_subset_pipeline(self, datamodule: OTXDataModule, subset: str = "te
data_configuration = datamodule.config
ov_test_config = self._load_default_config(model_name="openvino_model")["data"]["config"][f"{subset}_subset"]
subset_config = getattr(data_configuration, f"{subset}_subset")
subset_config.batch_size = ov_test_config["batch_size"]
subset_config.transform_lib_type = ov_test_config["transform_lib_type"]
subset_config.transforms = ov_test_config["transforms"]
data_configuration.tile_config.enable_tiler = False
msg = (
f"For OpenVINO IR models, Update the following {subset} transforms: {subset_config.transforms}"
f"and transform_lib_type: {subset_config.transform_lib_type}"
f"For OpenVINO IR models, Update the following {subset} \n"
f"\t transforms: {subset_config.transforms} \n"
f"\t transform_lib_type: {subset_config.transform_lib_type} \n"
f"\t batch_size: {subset_config.batch_size} \n"
"And the tiler is disabled."
)
warn(msg, stacklevel=1)
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from unittest.mock import create_autospec

import pytest
from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls
from otx.algo.classification.torchvision_model import OTXTVModel
from otx.core.config.device import DeviceConfig
from otx.core.model.entity.base import OVModel
from otx.core.types.export import OTXExportFormatType
from otx.core.types.precision import OTXPrecisionType
from otx.engine import Engine
Expand Down Expand Up @@ -120,6 +122,9 @@ def test_testing_with_ov_model(self, fxt_engine, mocker) -> None:
mock_test.assert_called_once()
mock_torch_load.assert_not_called()

fxt_engine.model = create_autospec(OVModel)
fxt_engine.test(checkpoint="path/to/model.xml")

def test_prediction_after_training(self, fxt_engine, mocker) -> None:
mocker.patch("otx.engine.engine.OTXLitModule.load_state_dict")
mock_predict = mocker.patch("otx.engine.engine.Trainer.predict")
Expand All @@ -144,6 +149,9 @@ def test_prediction_with_ov_model(self, fxt_engine, mocker) -> None:
mock_predict.assert_called_once()
mock_torch_load.assert_not_called()

fxt_engine.model = create_autospec(OVModel)
fxt_engine.predict(checkpoint="path/to/model.xml")

def test_prediction_explain_mode(self, fxt_engine, mocker) -> None:
mocker.patch("otx.engine.engine.OTXLitModule.load_state_dict")
mock_explain = mocker.patch("otx.algo.utils.xai_utils.process_saliency_maps_in_pred_entity")
Expand Down

0 comments on commit dbac83f

Please sign in to comment.