Skip to content

Commit

Permalink
bring tile-cli-2.0.0 to develop branch
Browse files Browse the repository at this point in the history
  • Loading branch information
eugene123tw committed Apr 2, 2024
1 parent a5305b6 commit 3b2d948
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
8 changes: 3 additions & 5 deletions src/otx/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,14 @@ def test(

is_ir_ckpt = Path(str(checkpoint)).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)
if self.device.accelerator != "cpu":
msg = "IR model supports inference only on CPU device. The device is changed automatic."
warn(msg, stacklevel=1)
self.device = DeviceType.cpu # type: ignore[assignment]

# NOTE: Re-initiate datamodule without tiling as model API supports its own tiling mechanism
if isinstance(model, OVModel) and isinstance(datamodule.subsets["test"], OTXTileDataset):
if isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")

# NOTE, trainer.test takes only lightning based checkpoint.
Expand Down Expand Up @@ -441,11 +440,10 @@ def predict(

is_ir_ckpt = checkpoint is not None and Path(checkpoint).suffix in [".xml", ".onnx"]
if is_ir_ckpt and not isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")
model = self._auto_configurator.get_ov_model(model_name=str(checkpoint), label_info=datamodule.label_info)

# NOTE: Re-initiate datamodule without tiling as model API supports its own tiling mechanism
if isinstance(model, OVModel) and isinstance(datamodule.subsets["test"], OTXTileDataset):
# NOTE: Re-initiate datamodule for OVModel as model API supports its own data pipeline.
if isinstance(model, OVModel):
datamodule = self._auto_configurator.update_ov_subset_pipeline(datamodule=datamodule, subset="test")

if checkpoint is not None and not is_ir_ckpt:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

from pathlib import Path
from unittest.mock import create_autospec

import pytest
from otx.algo.classification.efficientnet_b0 import EfficientNetB0ForMulticlassCls
from otx.algo.classification.torchvision_model import OTXTVModel
from otx.core.config.device import DeviceConfig
from otx.core.types.export import OTXExportFormatType
from otx.core.model.entity.base import OVModel
from otx.core.types.precision import OTXPrecisionType
from otx.engine import Engine

Expand Down Expand Up @@ -123,6 +125,9 @@ def test_testing_with_ov_model(self, fxt_engine, mocker) -> None:
mock_test.assert_called_once()
mock_torch_load.assert_not_called()

fxt_engine.model = create_autospec(OVModel)
fxt_engine.test(checkpoint="path/to/model.xml")

def test_prediction_after_training(self, fxt_engine, mocker) -> None:
mocker.patch("otx.engine.engine.OTXModel.load_state_dict")
mock_predict = mocker.patch("otx.engine.engine.Trainer.predict")
Expand All @@ -137,6 +142,9 @@ def test_prediction_after_training(self, fxt_engine, mocker) -> None:
fxt_engine.predict(checkpoint="path/to/new/checkpoint")
mock_torch_load.assert_called_with("path/to/new/checkpoint")

fxt_engine.model = create_autospec(OVModel)
fxt_engine.predict(checkpoint="path/to/model.xml")

def test_prediction_with_ov_model(self, fxt_engine, mocker) -> None:
mock_predict = mocker.patch("otx.engine.engine.Trainer.predict")
mock_torch_load = mocker.patch("torch.load")
Expand Down

0 comments on commit 3b2d948

Please sign in to comment.