From eb1a77335f2e7483804e1977c6b2cab24d4003cd Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Fri, 12 Apr 2024 17:27:37 +0900 Subject: [PATCH 1/5] Fix Signed-off-by: Kim, Vinnam --- .../algo/classification/torchvision_model.py | 14 +- src/otx/algo/segmentation/litehrnet.py | 584 +++++++++--------- src/otx/algo/segmentation/segnext.py | 2 +- src/otx/core/exporter/base.py | 19 +- src/otx/core/exporter/mmdeploy.py | 18 +- src/otx/core/exporter/native.py | 11 +- src/otx/core/exporter/visual_prompting.py | 21 +- src/otx/core/model/action_classification.py | 6 +- src/otx/core/model/base.py | 58 +- src/otx/core/model/classification.py | 14 +- src/otx/core/model/segmentation.py | 6 +- src/otx/core/model/utils/mmpretrain.py | 63 +- tests/integration/api/test_xai.py | 2 + tests/unit/core/exporter/test_mmdeploy.py | 2 +- .../core/exporter/test_visual_prompting.py | 24 +- 15 files changed, 435 insertions(+), 409 deletions(-) diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py index 210fea7def4..0fe311f5a74 100644 --- a/src/otx/algo/classification/torchvision_model.py +++ b/src/otx/algo/classification/torchvision_model.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal import torch -from torch import nn +from torch import Tensor, nn from torchvision.models import get_model, get_model_weights from otx.algo.explain.explain_algo import ReciproCAM @@ -312,11 +312,9 @@ def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassCls feature_vector=outputs["feature_vector"], ) - def _reset_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.__orig_model_forward = self.model.forward - self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001 + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - def _restore_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.model.forward = self.__orig_model_forward # type: ignore[method-assign] + return self.model(images=image, mode="tensor") diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py index 1116ae74c26..49122a9bd8d 100644 --- a/src/otx/algo/segmentation/litehrnet.py +++ b/src/otx/algo/segmentation/litehrnet.py @@ -90,89 +90,89 @@ def _obtain_ignored_scope(self) -> dict[str, Any]: """Returns the ignored scope for the model based on the litehrnet version.""" if self.model_name == "litehrnet_18": ignored_scope_names = [ - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/Add_1", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/Add_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", - 
"/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/Add_1", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/Add_2", - "/backbone/stage1/stage1.0/Add_5", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/Add_1", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/Add_2", - "/backbone/stage1/stage1.1/Add_5", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/Add_1", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/Add_2", - "/backbone/stage1/stage1.2/Add_5", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/Add_1", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/Add_2", - "/backbone/stage1/stage1.3/Add_5", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.0/Add_1", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.0/Add_2", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.0/Add_3", - "/backbone/stage2/stage2.0/Add_6", - "/backbone/stage2/stage2.0/Add_7", - "/backbone/stage2/stage2.0/Add_11", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul", - 
"/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.1/Add_1", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.1/Add_2", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.1/Add_3", - "/backbone/stage2/stage2.1/Add_6", - "/backbone/stage2/stage2.1/Add_7", - "/backbone/stage2/stage2.1/Add_11", - "/aggregator/Add", - "/aggregator/Add_1", - "/aggregator/Add_2", - "/backbone/stage2/stage2.1/Add", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/Add_1", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.1/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/Add_2", + "/model/backbone/stage1/stage1.0/Add_5", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/Add_1", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/Add_2", + "/model/backbone/stage1/stage1.1/Add_5", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/Add_1", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/Add_2", + "/model/backbone/stage1/stage1.2/Add_5", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", + 
"/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/Add_1", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.3/Add_2", + "/model/backbone/stage1/stage1.3/Add_5", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.0/Add_1", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.0/Add_2", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.0/Add_3", + "/model/backbone/stage2/stage2.0/Add_6", + "/model/backbone/stage2/stage2.0/Add_7", + "/model/backbone/stage2/stage2.0/Add_11", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.1/Add_1", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.1/Add_2", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.1/Add_3", + "/model/backbone/stage2/stage2.1/Add_6", + "/model/backbone/stage2/stage2.1/Add_7", + "/model/backbone/stage2/stage2.1/Add_11", + "/model/aggregator/Add", + "/model/aggregator/Add_1", + "/model/aggregator/Add_2", + "/model/backbone/stage2/stage2.1/Add", ] return { "ignored_scope": { - "patterns": ["/backbone/*"], + "patterns": ["/model/backbone/*"], "names": ignored_scope_names, }, "preset": "mixed", @@ -180,64 +180,64 @@ def _obtain_ignored_scope(self) -> dict[str, Any]: if self.model_name == "litehrnet_s": ignored_scope_names = [ - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/Add_1", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/Add_1", - "/backbone/stage0/stage0.2/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.2/layers/layers.0/cross_resolution_weighting/Mul_1", - 
"/backbone/stage0/stage0.2/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.2/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.2/Add_1", - "/backbone/stage0/stage0.3/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.3/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.3/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.3/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.3/Add_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/Add_1", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/Add_2", - "/backbone/stage1/stage1.0/Add_5", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/Add_1", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/Add_2", - "/backbone/stage1/stage1.1/Add_5", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/Add_1", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/Add_2", - "/backbone/stage1/stage1.2/Add_5", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/Add_1", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/Add_2", - "/backbone/stage1/stage1.3/Add_5", - "/aggregator/Add", - "/aggregator/Add_1", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/Add_1", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", + 
"/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.1/Add_1", + "/model/backbone/stage0/stage0.2/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.2/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.2/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.2/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.2/Add_1", + "/model/backbone/stage0/stage0.3/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.3/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.3/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.3/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.3/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/Add_2", + "/model/backbone/stage1/stage1.0/Add_5", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/Add_1", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/Add_2", + "/model/backbone/stage1/stage1.1/Add_5", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/Add_1", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/Add_2", + "/model/backbone/stage1/stage1.2/Add_5", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/Add_1", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.3/Add_2", + "/model/backbone/stage1/stage1.3/Add_5", + 
"/model/aggregator/Add", + "/model/aggregator/Add_1", ] return { @@ -249,165 +249,165 @@ def _obtain_ignored_scope(self) -> dict[str, Any]: if self.model_name == "litehrnet_x": ignored_scope_names = [ - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.0/Add_1", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage0/stage0.1/Add_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.0/Add_1", - "/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.0/Add_2", - "/backbone/stage1/stage1.0/Add_5", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.1/Add_1", - "/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.1/Add_2", - "/backbone/stage1/stage1.1/Add_5", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.2/Add_1", - "/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.2/Add_2", - "/backbone/stage1/stage1.2/Add_5", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage1/stage1.3/Add_1", - "/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage1/stage1.3/Add_2", - "/backbone/stage1/stage1.3/Add_5", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_2", - 
"/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.0/Add_1", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.0/Add_2", - "/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.0/Add_3", - "/backbone/stage2/stage2.0/Add_6", - "/backbone/stage2/stage2.0/Add_7", - "/backbone/stage2/stage2.0/Add_11", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.1/Add_1", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.1/Add_2", - "/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.1/Add_3", - "/backbone/stage2/stage2.1/Add_6", - "/backbone/stage2/stage2.1/Add_7", - "/backbone/stage2/stage2.1/Add_11", - "/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.2/Add_1", - "/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.2/Add_2", - "/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.2/Add_3", - "/backbone/stage2/stage2.2/Add_6", - "/backbone/stage2/stage2.2/Add_7", - "/backbone/stage2/stage2.2/Add_11", - "/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage2/stage2.3/Add_1", - "/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage2/stage2.3/Add_2", - "/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage2/stage2.3/Add_3", - "/backbone/stage2/stage2.3/Add_6", - "/backbone/stage2/stage2.3/Add_7", - "/backbone/stage2/stage2.3/Add_11", - "/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_3", - 
"/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_4", - "/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage3/stage3.0/Add_1", - "/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage3/stage3.0/Add_2", - "/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage3/stage3.0/Add_3", - "/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_4", - "/backbone/stage3/stage3.0/Add_4", - "/backbone/stage3/stage3.0/Add_7", - "/backbone/stage3/stage3.0/Add_8", - "/backbone/stage3/stage3.0/Add_9", - "/backbone/stage3/stage3.0/Add_13", - "/backbone/stage3/stage3.0/Add_14", - "/backbone/stage3/stage3.0/Add_19", - "/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul", - "/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_1", - "/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_2", - "/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_3", - "/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_4", - "/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul", - "/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_1", - "/backbone/stage3/stage3.1/Add_1", - "/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_2", - "/backbone/stage3/stage3.1/Add_2", - "/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_3", - "/backbone/stage3/stage3.1/Add_3", - "/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_4", - "/backbone/stage3/stage3.1/Add_4", - "/backbone/stage3/stage3.1/Add_7", - "/backbone/stage3/stage3.1/Add_8", - "/backbone/stage3/stage3.1/Add_9", - "/backbone/stage3/stage3.1/Add_13", - "/backbone/stage3/stage3.1/Add_14", - "/backbone/stage3/stage3.1/Add_19", - "/backbone/stage0/stage0.0/Add", - "/backbone/stage0/stage0.1/Add", - "/backbone/stage1/stage1.0/Add", - "/backbone/stage1/stage1.1/Add", - "/backbone/stage1/stage1.2/Add", - "/backbone/stage1/stage1.3/Add", - "/backbone/stage2/stage2.0/Add", - "/backbone/stage2/stage2.1/Add", - "/backbone/stage2/stage2.2/Add", - "/backbone/stage2/stage2.3/Add", - "/backbone/stage3/stage3.0/Add", - "/backbone/stage3/stage3.1/Add", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.0/Add_1", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage0/stage0.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage0/stage0.1/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul", + 
"/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.0/Add_1", + "/model/backbone/stage1/stage1.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.0/Add_2", + "/model/backbone/stage1/stage1.0/Add_5", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.1/Add_1", + "/model/backbone/stage1/stage1.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.1/Add_2", + "/model/backbone/stage1/stage1.1/Add_5", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.2/Add_1", + "/model/backbone/stage1/stage1.2/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.2/Add_2", + "/model/backbone/stage1/stage1.2/Add_5", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage1/stage1.3/Add_1", + "/model/backbone/stage1/stage1.3/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage1/stage1.3/Add_2", + "/model/backbone/stage1/stage1.3/Add_5", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.0/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.0/Add_1", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.0/Add_2", + "/model/backbone/stage2/stage2.0/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.0/Add_3", + "/model/backbone/stage2/stage2.0/Add_6", + "/model/backbone/stage2/stage2.0/Add_7", + "/model/backbone/stage2/stage2.0/Add_11", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.1/layers/layers.0/cross_resolution_weighting/Mul_3", + 
"/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.1/Add_1", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.1/Add_2", + "/model/backbone/stage2/stage2.1/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.1/Add_3", + "/model/backbone/stage2/stage2.1/Add_6", + "/model/backbone/stage2/stage2.1/Add_7", + "/model/backbone/stage2/stage2.1/Add_11", + "/model/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.2/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.2/Add_1", + "/model/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.2/Add_2", + "/model/backbone/stage2/stage2.2/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.2/Add_3", + "/model/backbone/stage2/stage2.2/Add_6", + "/model/backbone/stage2/stage2.2/Add_7", + "/model/backbone/stage2/stage2.2/Add_11", + "/model/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.3/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage2/stage2.3/Add_1", + "/model/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage2/stage2.3/Add_2", + "/model/backbone/stage2/stage2.3/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage2/stage2.3/Add_3", + "/model/backbone/stage2/stage2.3/Add_6", + "/model/backbone/stage2/stage2.3/Add_7", + "/model/backbone/stage2/stage2.3/Add_11", + "/model/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage3/stage3.0/layers/layers.0/cross_resolution_weighting/Mul_4", + "/model/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage3/stage3.0/Add_1", + "/model/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage3/stage3.0/Add_2", + "/model/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage3/stage3.0/Add_3", + "/model/backbone/stage3/stage3.0/layers/layers.1/cross_resolution_weighting/Mul_4", + "/model/backbone/stage3/stage3.0/Add_4", + 
"/model/backbone/stage3/stage3.0/Add_7", + "/model/backbone/stage3/stage3.0/Add_8", + "/model/backbone/stage3/stage3.0/Add_9", + "/model/backbone/stage3/stage3.0/Add_13", + "/model/backbone/stage3/stage3.0/Add_14", + "/model/backbone/stage3/stage3.0/Add_19", + "/model/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul", + "/model/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_1", + "/model/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_2", + "/model/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_3", + "/model/backbone/stage3/stage3.1/layers/layers.0/cross_resolution_weighting/Mul_4", + "/model/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul", + "/model/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_1", + "/model/backbone/stage3/stage3.1/Add_1", + "/model/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_2", + "/model/backbone/stage3/stage3.1/Add_2", + "/model/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_3", + "/model/backbone/stage3/stage3.1/Add_3", + "/model/backbone/stage3/stage3.1/layers/layers.1/cross_resolution_weighting/Mul_4", + "/model/backbone/stage3/stage3.1/Add_4", + "/model/backbone/stage3/stage3.1/Add_7", + "/model/backbone/stage3/stage3.1/Add_8", + "/model/backbone/stage3/stage3.1/Add_9", + "/model/backbone/stage3/stage3.1/Add_13", + "/model/backbone/stage3/stage3.1/Add_14", + "/model/backbone/stage3/stage3.1/Add_19", + "/model/backbone/stage0/stage0.0/Add", + "/model/backbone/stage0/stage0.1/Add", + "/model/backbone/stage1/stage1.0/Add", + "/model/backbone/stage1/stage1.1/Add", + "/model/backbone/stage1/stage1.2/Add", + "/model/backbone/stage1/stage1.3/Add", + "/model/backbone/stage2/stage2.0/Add", + "/model/backbone/stage2/stage2.1/Add", + "/model/backbone/stage2/stage2.2/Add", + "/model/backbone/stage2/stage2.3/Add", + "/model/backbone/stage3/stage3.0/Add", + "/model/backbone/stage3/stage3.1/Add", ] return { "ignored_scope": { - "patterns": ["/aggregator/*"], + "patterns": ["/model/aggregator/*"], "names": ignored_scope_names, }, "preset": "performance", diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py index 03c2afaaffa..b042a74c848 100644 --- a/src/otx/algo/segmentation/segnext.py +++ b/src/otx/algo/segmentation/segnext.py @@ -73,7 +73,7 @@ def _optimization_config(self) -> dict[str, Any]: # TODO(Kirill): check PTQ removing hamburger from ignored_scope return { "ignored_scope": { - "patterns": ["__module.decode_head.hamburger*"], + "patterns": ["__module.model.decode_head.hamburger*"], "types": [ "Add", "MVN", diff --git a/src/otx/core/exporter/base.py b/src/otx/core/exporter/base.py index 410f4a898d0..2a6b8735eaf 100644 --- a/src/otx/core/exporter/base.py +++ b/src/otx/core/exporter/base.py @@ -22,7 +22,8 @@ if TYPE_CHECKING: import onnx import openvino - import torch + + from otx.core.model.base import OTXModel class OTXModelExporter: @@ -74,7 +75,7 @@ def metadata(self) -> dict[tuple[str, str], str]: def export( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO, @@ -83,7 +84,7 @@ def export( """Exports input model to the specified deployable format, such as OpenVINO IR or ONNX. 
Args: - model (torch.nn.Module): pytorch model top export + model (OTXModel): OTXModel to be exported output_dir (Path): path to the directory to store export artifacts base_model_name (str, optional): exported model name format (OTXExportFormatType): final format of the exported model @@ -110,7 +111,7 @@ def export( @abstractmethod def to_openvino( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -118,7 +119,7 @@ def to_openvino( """Export to OpenVINO Intermediate Representation format. Args: - model (torch.nn.Module): pytorch model top export + model (OTXModel): OTXModel to be exported output_dir (Path): path to the directory to store export artifacts base_model_name (str, optional): exported model name precision (OTXExportPrecisionType, optional): precision of the exported model's weights @@ -130,7 +131,7 @@ def to_openvino( @abstractmethod def to_onnx( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -141,7 +142,7 @@ def to_onnx( Converts the given torch model to ONNX format and saves it to the specified output directory. Args: - model (torch.nn.Module): The input PyTorch model to be converted. + model (OTXModel): The input PyTorch model to be converted. output_dir (Path): The directory where the ONNX model will be saved. base_model_name (str, optional): The name of the exported ONNX model. Defaults to "exported_model". precision (OTXPrecisionType, optional): The precision type for the exported model. @@ -154,7 +155,7 @@ def to_onnx( def to_exportable_code( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -162,7 +163,7 @@ def to_exportable_code( """Export to zip folder final OV IR model with runable demo. Args: - model (torch.nn.Module): pytorch model top export + model (OTXModel): OTXModel to be exported output_dir (Path): path to the directory to store export artifacts base_model_name (str, optional): exported model name precision (OTXExportPrecisionType, optional): precision of the exported model's weights diff --git a/src/otx/core/exporter/mmdeploy.py b/src/otx/core/exporter/mmdeploy.py index fd2abe90bf3..7af26ce46e6 100644 --- a/src/otx/core/exporter/mmdeploy.py +++ b/src/otx/core/exporter/mmdeploy.py @@ -30,6 +30,8 @@ from omegaconf import DictConfig + from otx.core.model.base import OTXModel + class MMdeployExporter(OTXModelExporter): """Exporter that uses mmdeploy and OpenVINO conversion tools. @@ -102,7 +104,7 @@ def _set_max_num_detections(self, max_num_detections: int) -> None: def to_openvino( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -110,7 +112,7 @@ def to_openvino( """Export to OpenVINO Intermediate Representation format. 
         Args:
-            model (torch.nn.Module): pytorch model top export
+            model (OTXModel): OTXModel to be exported
             output_dir (Path): path to the directory to store export artifacts
             base_model_name (str, optional): exported model name
             precision (OTXPrecisionType, optional): precision of the exported model's weights
@@ -135,7 +137,7 @@ def to_openvino(
 
     def to_onnx(
         self,
-        model: torch.nn.Module,
+        model: OTXModel,
         output_dir: Path,
         base_model_name: str = "exported_model",
         precision: OTXPrecisionType = OTXPrecisionType.FP32,
@@ -144,7 +146,7 @@ def to_onnx(
         """Export to ONNX format.
 
         Args:
-            model (torch.nn.Module): pytorch model top export
+            model (OTXModel): OTXModel to be exported
             output_dir (Path): path to the directory to store export artifacts
             base_model_name (str, optional): exported model name
             precision (OTXPrecisionType, optional): precision of the exported model's weights
@@ -171,14 +173,17 @@ def _prepare_onnx_cfg(self) -> MMConfig:
 
     def _cvt2onnx(
         self,
-        model: torch.nn.Module,
+        model: OTXModel,
         output_dir: Path,
         base_model_name: str,
         deploy_cfg: MMConfig | None = None,
     ) -> Path:
         onnx_file_name = base_model_name + ".onnx"
         model_weight_file = output_dir / "mmdeploy_fmt_model.pth"
-        torch.save(model.state_dict(), model_weight_file)
+        # NOTE: This class doesn't actually use the given model instance for graph tracing.
+        # It just borrows weights of the given model instance.
+        mm_model = model.model
+        torch.save(mm_model.state_dict(), model_weight_file)
 
         log.debug(f"mmdeploy torch2onnx: \n\tmodel_cfg: {self._model_cfg}\n\tdeploy_cfg: {self._deploy_cfg}")
 
         with use_temporary_default_scope():
@@ -188,6 +193,7 @@ def _cvt2onnx(
                 str(output_dir),
                 onnx_file_name,
                 deploy_cfg=self._deploy_cfg if deploy_cfg is None else deploy_cfg,
+                # NOTE: The actual model instance for graph tracing is created by this `model_cfg`.
                 model_cfg=self._model_cfg,
                 model_checkpoint=str(model_weight_file),
                 device="cpu",
diff --git a/src/otx/core/exporter/native.py b/src/otx/core/exporter/native.py
index 27a7e60eee2..11f90b9451d 100644
--- a/src/otx/core/exporter/native.py
+++ b/src/otx/core/exporter/native.py
@@ -8,7 +8,7 @@
 import logging as log
 import tempfile
 from pathlib import Path
-from typing import Any, Literal
+from typing import TYPE_CHECKING, Any, Literal
 
 import onnx
 import openvino
@@ -18,6 +18,9 @@
 from otx.core.types.export import TaskLevelExportParameters
 from otx.core.types.precision import OTXPrecisionType
 
+if TYPE_CHECKING:
+    from otx.core.model.base import OTXModel
+
 
 class OTXNativeModelExporter(OTXModelExporter):
     """Exporter that uses native torch and OpenVINO conversion tools."""
@@ -52,7 +55,7 @@ def __init__(
 
     def to_openvino(
         self,
-        model: torch.nn.Module,
+        model: OTXModel,
         output_dir: Path,
         base_model_name: str = "exported_model",
         precision: OTXPrecisionType = OTXPrecisionType.FP32,
@@ -94,7 +97,7 @@ def to_openvino(
 
     def to_onnx(
         self,
-        model: torch.nn.Module,
+        model: OTXModel,
         output_dir: Path,
         base_model_name: str = "exported_model",
         precision: OTXPrecisionType = OTXPrecisionType.FP32,
@@ -103,7 +106,7 @@ def to_onnx(
         """Export the given PyTorch model to ONNX format and save it to the specified output directory.
 
         Args:
-            model (torch.nn.Module): The PyTorch model to be exported.
+            model (OTXModel): The PyTorch model to be exported.
             output_dir (Path): The directory where the ONNX model will be saved.
             base_model_name (str, optional): The base name for the exported model. Defaults to "exported_model".
             precision (OTXPrecisionType, optional): The precision type for the exported model.
diff --git a/src/otx/core/exporter/visual_prompting.py b/src/otx/core/exporter/visual_prompting.py index 67d3de5d74c..b04fd6105c5 100644 --- a/src/otx/core/exporter/visual_prompting.py +++ b/src/otx/core/exporter/visual_prompting.py @@ -8,7 +8,7 @@ import logging as log import tempfile from pathlib import Path -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal import onnx import openvino @@ -18,13 +18,16 @@ from otx.core.types.export import OTXExportFormatType from otx.core.types.precision import OTXPrecisionType +if TYPE_CHECKING: + from otx.core.model.base import OTXModel + class OTXVisualPromptingModelExporter(OTXNativeModelExporter): """Exporter for visual prompting models that uses native torch and OpenVINO conversion tools.""" def export( # type: ignore[override] self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", export_format: OTXExportFormatType = OTXExportFormatType.OPENVINO, @@ -33,7 +36,7 @@ def export( # type: ignore[override] """Exports input model to the specified deployable format, such as OpenVINO IR or ONNX. Args: - model (torch.nn.Module): pytorch model top export + model (OTXModel): OTXModel to be exported output_dir (Path): path to the directory to store export artifacts base_model_name (str, optional): exported model name format (OTXExportFormatType): final format of the exported model @@ -42,9 +45,11 @@ def export( # type: ignore[override] Returns: dict[str, Path]: paths to the exported models """ + # NOTE: Rather than using OTXModel.forward_for_tracing() + # Use the nested `image_encoder` and `decoder` models' forward functions directly models: dict[str, torch.nn.Module] = { - "image_encoder": model.image_encoder, - "decoder": model, + "image_encoder": model.model.image_encoder, + "decoder": model.model, } if export_format == OTXExportFormatType.OPENVINO: @@ -65,7 +70,7 @@ def export( # type: ignore[override] def to_openvino( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -99,7 +104,7 @@ def to_openvino( def to_onnx( self, - model: torch.nn.Module, + model: OTXModel, output_dir: Path, base_model_name: str = "exported_model", precision: OTXPrecisionType = OTXPrecisionType.FP32, @@ -108,7 +113,7 @@ def to_onnx( """Export the given PyTorch model to ONNX format and save it to the specified output directory. Args: - model (torch.nn.Module): The PyTorch model to be exported. + model (OTXModel): OTXModel to be exported. output_dir (Path): The directory where the ONNX model will be saved. base_model_name (str, optional): The base name for the exported model. Defaults to "exported_model". precision (OTXPrecisionType, optional): The precision type for the exported model. 
diff --git a/src/otx/core/model/action_classification.py b/src/otx/core/model/action_classification.py index 0c662f0fea0..59228945761 100644 --- a/src/otx/core/model/action_classification.py +++ b/src/otx/core/model/action_classification.py @@ -26,7 +26,7 @@ from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from omegaconf import DictConfig from openvino.model_api.models.utils import ClassificationResult - from torch import nn + from torch import Tensor, nn from otx.core.exporter.base import OTXModelExporter from otx.core.metrics import MetricCallable @@ -190,6 +190,10 @@ def _exporter(self) -> OTXModelExporter: output_names=None, ) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model(inputs=image, mode="tensor") + class OVActionClsModel( OVModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity], diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index a4bf4943e0e..0bf0c4312a2 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -19,7 +19,7 @@ import openvino import torch from jsonargparse import ArgumentParser -from lightning import LightningModule +from lightning import LightningModule, Trainer from openvino.model_api.models import Model from openvino.model_api.tilers import Tiler from torch import Tensor, nn @@ -35,7 +35,6 @@ T_OTXBatchPredEntity, ) from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity -from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput, NullMetricCallable from otx.core.optimizer.callable import OptimizerCallableSupportHPO @@ -57,6 +56,7 @@ from torch.optim.optimizer import Optimizer, params_t from otx.core.data.module import OTXDataModule + from otx.core.exporter.base import OTXModelExporter from otx.core.metrics import MetricCallable logger = logging.getLogger() @@ -115,7 +115,7 @@ def __init__( self.torch_compile = torch_compile self._explain_mode = False - self._tile_config: TileConfig | None = None + self._tile_config = TileConfig(enable_tiler=False) # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt @@ -346,9 +346,7 @@ def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: checkpoint["label_info"] = self.label_info checkpoint["otx_version"] = __version__ - - if self._tile_config: - checkpoint["tile_config"] = self._tile_config + checkpoint["tile_config"] = self.tile_config def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: """Callback on loading checkpoint.""" @@ -358,7 +356,7 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: self._label_info = ckpt_label_info if ckpt_tile_config := checkpoint.get("tile_config", None): - self._tile_config = ckpt_tile_config + self.tile_config = ckpt_tile_config def load_state_dict_incrementally(self, ckpt: dict[str, Any], *args, **kwargs) -> None: """Load state dict incrementally.""" @@ -499,20 +497,21 @@ def forward( def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity: """Model forward explain function.""" - raise NotImplementedError + msg = "Derived model class should implement this class to support the explain pipeline." 
+        raise NotImplementedError(msg)
+
+    def forward_for_tracing(self, *args, **kwargs) -> Tensor | dict[str, Tensor]:
+        """Model forward function used for the model tracing during model exportation."""
+        msg = (
+            "Derived model class should implement this method to support the export pipeline "
+            "if it wants to use `otx.core.exporter.native.OTXNativeModelExporter`."
+        )
+        raise NotImplementedError(msg)
 
     def get_explain_fn(self) -> Callable:
         """Returns explain function."""
         raise NotImplementedError
 
-    def _reset_model_forward(self) -> None:
-        # TODO(vinnamkim): This will be revisited by the export refactoring
-        pass
-
-    def _restore_model_forward(self) -> None:
-        # TODO(vinnamkim): This will be revisited by the export refactoring
-        pass
-
     def forward_tiles(
         self,
         inputs: T_OTXTileBatchDataEntity,
@@ -610,16 +609,23 @@ def export(
         Returns:
             Path: path to the exported model.
         """
-        self._reset_model_forward()
-        exported_model_path = self._exporter.export(
-            self.model,
-            output_dir,
-            base_name,
-            export_format,
-            precision,
-        )
-        self._restore_model_forward()
-        return exported_model_path
+        mode = self.training
+        self.eval()
+
+        orig_forward = self.forward
+        try:
+            self._trainer = Trainer()
+            self.forward = self.forward_for_tracing  # type: ignore[method-assign, assignment]
+            return self._exporter.export(
+                self,
+                output_dir,
+                base_name,
+                export_format,
+                precision,
+            )
+        finally:
+            self.train(mode)
+            self.forward = orig_forward  # type: ignore[method-assign]
 
     @property
     def _exporter(self) -> OTXModelExporter:
diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py
index a88c2252511..ed1210c368f 100644
--- a/src/otx/core/model/classification.py
+++ b/src/otx/core/model/classification.py
@@ -41,7 +41,7 @@
     from mmpretrain.models.utils import ClsDataPreprocessor
     from omegaconf import DictConfig
     from openvino.model_api.models.utils import ClassificationResult
-    from torch import nn
+    from torch import Tensor, nn
 
     from otx.core.metrics import MetricCallable
 
@@ -236,6 +236,10 @@ def _exporter(self) -> OTXModelExporter:
             output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None,
         )
 
+    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
+        """Model forward function used for the model tracing during model exportation."""
+        return self.model.forward(image, mode="tensor")
+
 
 ### NOTE, currently, although we've made the separate Multi-cls, Multi-label classes
 ### It'll be integrated after H-label classification integration with more advanced design.
@@ -287,6 +291,10 @@ def _convert_pred_entity_to_compute_metric(
             "target": torch.stack(inputs.labels),
         }
 
+    def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
+        """Model forward function used for the model tracing during model exportation."""
+        return self.model.forward(image, mode="tensor")
+
 
 class MMPretrainMultilabelClsModel(OTXMultilabelClsModel):
     """Multi-label Classification model compatible for MMPretrain.
@@ -641,6 +649,10 @@ def _exporter(self) -> OTXModelExporter: output_names=["logits", "feature_vector", "saliency_map"] if self.explain_mode else None, ) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model.forward(image, mode="tensor") + class OVMulticlassClassificationModel( OVModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity], diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index 3beb136a2fe..2bf9d6b4a3b 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -26,7 +26,7 @@ from mmseg.models.data_preprocessor import SegDataPreProcessor from omegaconf import DictConfig from openvino.model_api.models.utils import ImageResultWithSoftPrediction - from torch import nn + from torch import Tensor, nn from otx.core.metrics import MetricCallable @@ -173,6 +173,10 @@ def _customize_outputs( masks=masks, ) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model(inputs=image, mode="tensor") + class OVSegmentationModel(OVModel[SegBatchDataEntity, SegBatchPredEntity]): """Semantic segmentation model compatible for OpenVINO IR inference. diff --git a/src/otx/core/model/utils/mmpretrain.py b/src/otx/core/model/utils/mmpretrain.py index 0b946cd71ac..1e2366cb8cb 100644 --- a/src/otx/core/model/utils/mmpretrain.py +++ b/src/otx/core/model/utils/mmpretrain.py @@ -20,7 +20,7 @@ from mmpretrain.models.classifiers.image import ImageClassifier from mmpretrain.structures import DataSample from omegaconf import DictConfig - from torch import device, nn + from torch import Tensor, device, nn @MODELS.register_module(force=True) @@ -141,49 +141,42 @@ def get_explain_fn(self) -> Callable: ) return explainer.func + def _register(self) -> None: + if getattr(self, "_registered", False): + return + self.model.feature_vector_fn = get_feature_vector + self.model.explain_fn = self.get_explain_fn() + self._registered = True + def forward_explain( self, inputs: T_OTXBatchDataEntity, ) -> T_OTXBatchPredEntity: """Model forward function.""" - forward_func: Callable[[T_OTXBatchDataEntity], T_OTXBatchPredEntity] | None = getattr(self, "forward", None) - - if forward_func is None: - msg = ( - "This instance has no forward function. " - "Did you attach this mixin into a class derived from OTXModel?" - ) - raise RuntimeError(msg) + self._register() + orig_model_forward = self.model.forward try: - self._reset_model_forward() - return forward_func(inputs) - finally: - self._restore_model_forward() + self.model.forward = types.MethodType(self._forward_explain_image_classifier, self.model) - def _reset_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - if not self.explain_mode: - return + forward_func: Callable[[T_OTXBatchDataEntity], T_OTXBatchPredEntity] | None = getattr(self, "forward", None) - self.model.feature_vector_fn = get_feature_vector - self.model.explain_fn = self.get_explain_fn() - forward_with_explain = self._forward_explain_image_classifier - - self.original_model_forward = self.model.forward - - func_type = types.MethodType - self.model.forward = func_type(forward_with_explain, self.model) + if forward_func is None: + msg = ( + "This instance has no forward function. " + "Did you attach this mixin into a class derived from OTXModel?" 
+ ) + raise RuntimeError(msg) - def _restore_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - if not self.explain_mode: - return + return forward_func(inputs) + finally: + self.model.forward = orig_model_forward - if not self.original_model_forward: - msg = "Original model forward was not saved." - raise RuntimeError(msg) + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + self._register() + forward_explain = types.MethodType(self._forward_explain_image_classifier, self.model) + return forward_explain(inputs=image, mode="tensor") - func_type = types.MethodType - self.model.forward = func_type(self.original_model_forward, self.model) - self.original_model_forward = None + return self.model.forward(inputs=image, mode="tensor") diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index 26f7ace117d..5ab701c4d52 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -24,6 +24,7 @@ @pytest.mark.parametrize( "recipe", EXPLAIN_MODEL_LIST, + ids=lambda x: "/".join(Path(x).parts[-2:]), ) def test_forward_explain( recipe: str, @@ -70,6 +71,7 @@ def test_forward_explain( @pytest.mark.parametrize( "recipe", EXPLAIN_MODEL_LIST, + ids=lambda x: "/".join(Path(x).parts[-2:]), ) def test_predict_with_explain( recipe: str, diff --git a/tests/unit/core/exporter/test_mmdeploy.py b/tests/unit/core/exporter/test_mmdeploy.py index 8dcaa4033b4..4c036e82e1f 100644 --- a/tests/unit/core/exporter/test_mmdeploy.py +++ b/tests/unit/core/exporter/test_mmdeploy.py @@ -183,7 +183,7 @@ def test_cvt2onnx(self, mocker, mock_model, output_dir, base_model_name, mock_to assert output_dir / f"{base_model_name}.onnx" == exporter._cvt2onnx(mock_model, output_dir, base_model_name) mock_torch.save.assert_called_once() - assert mock_torch.save.call_args.args[0] == mock_model.state_dict() + assert mock_torch.save.call_args.args[0] == mock_model.model.state_dict() mock_use_temporary_default_scope.assert_called_once() mock_build_task_processor.assert_called_once() mock_torch2onnx.assert_called_once() diff --git a/tests/unit/core/exporter/test_visual_prompting.py b/tests/unit/core/exporter/test_visual_prompting.py index 9050a3e28d6..46931de8cc6 100644 --- a/tests/unit/core/exporter/test_visual_prompting.py +++ b/tests/unit/core/exporter/test_visual_prompting.py @@ -7,18 +7,6 @@ import pytest from otx.core.exporter.visual_prompting import OTXVisualPromptingModelExporter from otx.core.types.export import OTXExportFormatType -from torch import nn - - -class MockModel(nn.Module): - def __init__(self): - super().__init__() - self.image_encoder = nn.Identity() - self.embed_dim = 2 - self.image_embedding_size = 4 - - def forward(self, x): - return x class TestOTXVisualPromptingModelExporter: @@ -45,9 +33,10 @@ def test_export_openvino(self, mocker, tmpdir, otx_visual_prompting_model_export "_postprocess_openvino_model", ) mocker_openvino_save_model = mocker.patch("openvino.save_model") + mock_model = mocker.MagicMock() otx_visual_prompting_model_exporter.export( - model=MockModel(), + model=mock_model, output_dir=tmpdir, export_format=OTXExportFormatType.OPENVINO, ) @@ -69,9 +58,10 @@ def test_export_onnx(self, mocker, tmpdir, otx_visual_prompting_model_exporter) otx_visual_prompting_model_exporter, "_postprocess_onnx_model", ) + mock_model = mocker.MagicMock() 
otx_visual_prompting_model_exporter.export( - model=MockModel(), + model=mock_model, output_dir=tmpdir, export_format=OTXExportFormatType.ONNX, ) @@ -81,11 +71,13 @@ def test_export_onnx(self, mocker, tmpdir, otx_visual_prompting_model_exporter) mocker_onnx_save.assert_called() mocker_postprocess_onnx_model.assert_called() - def test_export_exportable_code(self, tmpdir, otx_visual_prompting_model_exporter) -> None: + def test_export_exportable_code(self, mocker, tmpdir, otx_visual_prompting_model_exporter) -> None: """Test export for EXPORTABLE_CODE.""" + mock_model = mocker.MagicMock() + with pytest.raises(NotImplementedError): otx_visual_prompting_model_exporter.export( - model=MockModel(), + model=mock_model, output_dir=tmpdir, export_format=OTXExportFormatType.EXPORTABLE_CODE, ) From 07e87e69e79ca1e914c5c3e7c28a970217f15ed5 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 16 Apr 2024 00:57:39 +0900 Subject: [PATCH 2/5] Revert test_xai.py Signed-off-by: Kim, Vinnam --- tests/integration/api/test_xai.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py index 5ab701c4d52..26f7ace117d 100644 --- a/tests/integration/api/test_xai.py +++ b/tests/integration/api/test_xai.py @@ -24,7 +24,6 @@ @pytest.mark.parametrize( "recipe", EXPLAIN_MODEL_LIST, - ids=lambda x: "/".join(Path(x).parts[-2:]), ) def test_forward_explain( recipe: str, @@ -71,7 +70,6 @@ def test_forward_explain( @pytest.mark.parametrize( "recipe", EXPLAIN_MODEL_LIST, - ids=lambda x: "/".join(Path(x).parts[-2:]), ) def test_predict_with_explain( recipe: str, From 264286867ca9c906c113331402b7cf8bb7c5d257 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 16 Apr 2024 10:40:46 +0900 Subject: [PATCH 3/5] Fix Signed-off-by: Kim, Vinnam --- src/otx/algo/classification/otx_dino_v2.py | 6 +++++- src/otx/core/model/base.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/otx/algo/classification/otx_dino_v2.py b/src/otx/algo/classification/otx_dino_v2.py index 63f5c82f97c..55994341f5a 100644 --- a/src/otx/algo/classification/otx_dino_v2.py +++ b/src/otx/algo/classification/otx_dino_v2.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Any import torch -from torch import nn +from torch import Tensor, nn from otx.algo.utils.mmconfig import read_mmconfig from otx.core.data.entity.base import OTXBatchLossEntity @@ -162,3 +162,7 @@ def _exporter(self) -> OTXModelExporter: def _optimization_config(self) -> dict[str, Any]: """PTQ config for DinoV2Cls.""" return {"model_type": "transformer"} + + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + return self.model(image) diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index 0bf0c4312a2..288784f8e90 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -613,8 +613,11 @@ def export( self.eval() orig_forward = self.forward + orig_trainer = self._trainer # type: ignore[has-type] + try: - self._trainer = Trainer() + if self._trainer is None: # type: ignore[has-type] + self._trainer = Trainer() self.forward = self.forward_for_tracing # type: ignore[method-assign, assignment] return self._exporter.export( self, @@ -626,6 +629,7 @@ def export( finally: self.train(mode) self.forward = orig_forward # type: ignore[method-assign] + self._trainer = orig_trainer @property def _exporter(self) -> OTXModelExporter: From 
4a2c294bdee3fc40c11dfc6636ad65fbf10076de Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Tue, 16 Apr 2024 11:32:31 +0900 Subject: [PATCH 4/5] Fix Signed-off-by: Kim, Vinnam --- src/otx/algo/detection/atss.py | 29 +++++----- src/otx/algo/detection/rtmdet.py | 29 +++++----- src/otx/algo/detection/ssd.py | 29 +++++----- src/otx/algo/detection/yolox.py | 58 ++++++++++--------- .../algo/instance_segmentation/maskrcnn.py | 58 ++++++++++--------- .../algo/instance_segmentation/rtmdet_inst.py | 29 +++++----- src/otx/core/model/detection.py | 16 ++++- src/otx/core/model/instance_segmentation.py | 16 ++++- 8 files changed, 150 insertions(+), 114 deletions(-) diff --git a/src/otx/algo/detection/atss.py b/src/otx/algo/detection/atss.py index 91a4a1b5bbe..cd3c99dfdbb 100644 --- a/src/otx/algo/detection/atss.py +++ b/src/otx/algo/detection/atss.py @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.detection.mmdeploy.atss", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="standard", - pad_value=0, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.atss", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py index 2965c32c22d..73da35d4e3d 100644 --- a/src/otx/algo/detection/rtmdet.py +++ b/src/otx/algo/detection/rtmdet.py @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.detection.mmdeploy.rtmdet", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="fit_to_window_letterbox", - pad_value=114, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.rtmdet", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" diff --git a/src/otx/algo/detection/ssd.py 
b/src/otx/algo/detection/ssd.py index 6f37dfa5e21..43b4a3d896e 100644 --- a/src/otx/algo/detection/ssd.py +++ b/src/otx/algo/detection/ssd.py @@ -563,20 +563,21 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="standard", - pad_value=0, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None: """Callback on load checkpoint.""" diff --git a/src/otx/algo/detection/yolox.py b/src/otx/algo/detection/yolox.py index 15b56b2bbae..965460ae4b8 100644 --- a/src/otx/algo/detection/yolox.py +++ b/src/otx/algo/detection/yolox.py @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.detection.mmdeploy.yolox", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="fit_to_window_letterbox", - pad_value=114, - swap_rgb=True, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.detection.mmdeploy.yolox", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=True, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -109,17 +110,18 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="fit_to_window_letterbox", - pad_value=114, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + 
deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/algo/instance_segmentation/maskrcnn.py b/src/otx/algo/instance_segmentation/maskrcnn.py index a25e4f86634..fb899ae904c 100644 --- a/src/otx/algo/instance_segmentation/maskrcnn.py +++ b/src/otx/algo/instance_segmentation/maskrcnn.py @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving - pad_value=0, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict: """Load the previous OTX ckpt according to OTX2.0.""" @@ -109,17 +110,18 @@ def _exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving - pad_value=0, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving + pad_value=0, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py index 7a751bfe3a0..e545482b0cf 100644 --- a/src/otx/algo/instance_segmentation/rtmdet_inst.py +++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py @@ -56,17 +56,18 @@ def 
_exporter(self) -> OTXModelExporter: mean, std = get_mean_std_from_data_processing(self.config) - return MMdeployExporter( - model_builder=self._create_model, - model_cfg=deepcopy(self.config), - deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst", - test_pipeline=self._make_fake_test_pipeline(), - task_level_export_parameters=self._export_parameters, - input_size=self.image_size, - mean=mean, - std=std, - resize_mode="fit_to_window_letterbox", - pad_value=114, - swap_rgb=False, - output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, - ) + with self.export_model_forward_context(): + return MMdeployExporter( + model_builder=self._create_model, + model_cfg=deepcopy(self.config), + deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst", + test_pipeline=self._make_fake_test_pipeline(), + task_level_export_parameters=self._export_parameters, + input_size=self.image_size, + mean=mean, + std=std, + resize_mode="fit_to_window_letterbox", + pad_value=114, + swap_rgb=False, + output_names=["feature_vector", "saliency_map"] if self.explain_mode else None, + ) diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 8f4a609931c..6edbccbe280 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -7,7 +7,8 @@ import logging as log import types -from typing import TYPE_CHECKING, Any, Callable, Literal +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal import torch from openvino.model_api.models import Model @@ -263,6 +264,19 @@ def get_explain_fn(self) -> Callable: ) return explainer.func + @contextmanager + def export_model_forward_context(self) -> Iterator[None]: + """A context manager for managing the model's forward function during model exportation. + + It temporarily modifies the model's forward function to generate output sinks + for explain results during the model graph tracing. + """ + try: + self._reset_model_forward() + yield + finally: + self._restore_model_forward() + def _reset_model_forward(self) -> None: if not self.explain_mode: return diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index cc6936a8719..6e2e76dfafb 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -7,7 +7,8 @@ import logging as log import types -from typing import TYPE_CHECKING, Any, Callable, Literal +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal import numpy as np import torch @@ -293,6 +294,19 @@ def get_explain_fn(self) -> Callable: explainer = MaskRCNNExplainAlgo(num_classes=self.num_classes) return explainer.func + @contextmanager + def export_model_forward_context(self) -> Iterator[None]: + """A context manager for managing the model's forward function during model exportation. + + It temporarily modifies the model's forward function to generate output sinks + for explain results during the model graph tracing. 
+ """ + try: + self._reset_model_forward() + yield + finally: + self._restore_model_forward() + def _reset_model_forward(self) -> None: if not self.explain_mode: return From 6b7efc0d140ae3c9288f23a88fc763599e15f9e7 Mon Sep 17 00:00:00 2001 From: "Kim, Vinnam" Date: Wed, 17 Apr 2024 19:06:42 +0900 Subject: [PATCH 5/5] Fix mobilenetv3 Signed-off-by: Kim, Vinnam --- src/otx/algo/classification/mobilenet_v3.py | 38 +++++++++------------ 1 file changed, 16 insertions(+), 22 deletions(-) diff --git a/src/otx/algo/classification/mobilenet_v3.py b/src/otx/algo/classification/mobilenet_v3.py index 75acb8d3b01..85150aac745 100644 --- a/src/otx/algo/classification/mobilenet_v3.py +++ b/src/otx/algo/classification/mobilenet_v3.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING, Any, Callable, Literal import torch -from torch import nn +from torch import Tensor, nn from otx.algo.classification.backbones import OTXMobileNetV3 from otx.algo.classification.classifier.base_classifier import ImageClassifier @@ -158,14 +158,12 @@ def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassCls feature_vector=outputs["feature_vector"], ) - def _reset_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.__orig_model_forward = self.model.forward - self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001 + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - def _restore_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.model.forward = self.__orig_model_forward # type: ignore[method-assign] + return self.model(images=image, mode="tensor") class MobileNetV3ForMultilabelCls(OTXMultilabelClsModel): @@ -277,14 +275,12 @@ def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelCls feature_vector=outputs["feature_vector"], ) - def _reset_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.__orig_model_forward = self.model.forward - self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001 + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - def _restore_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.model.forward = self.__orig_model_forward # type: ignore[method-assign] + return self.model(images=image, mode="tensor") class MobileNetV3ForHLabelCls(OTXHlabelClsModel): @@ -423,11 +419,9 @@ def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPre feature_vector=outputs["feature_vector"], ) - def _reset_model_forward(self) -> None: - # TODO(vinnamkim): This will be revisited by the export refactoring - self.__orig_model_forward = self.model.forward - self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001 + def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]: + """Model forward function used for the model tracing during model exportation.""" + if self.explain_mode: + return self.model(images=image, mode="explain") - def _restore_model_forward(self) -> None: - 
# TODO(vinnamkim): This will be revisited by the export refactoring - self.model.forward = self.__orig_model_forward # type: ignore[method-assign] + return self.model(images=image, mode="tensor")
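
For reference, the pattern these patches converge on can be summarised with a minimal, self-contained sketch. It is illustrative only: `TinyModel` and the `tracing_forward` helper below are hypothetical stand-ins, not OTX classes; the real flow goes through `OTXModel.export()` and the `OTXModelExporter` hierarchy. The point shown is the contract itself: the training-time `forward` consumes a batch entity, `forward_for_tracing` consumes a plain tensor, and exporting temporarily routes `forward` to it, restoring the original forward and train/eval mode afterwards, so the tracer sees the export graph.

    # Illustrative sketch only -- TinyModel and tracing_forward are hypothetical
    # stand-ins, not OTX classes; the actual flow lives in OTXModel.export().
    from __future__ import annotations

    import contextlib

    import torch
    from torch import Tensor, nn


    class TinyModel(nn.Module):
        """Stand-in for an OTXModel subclass."""

        def __init__(self) -> None:
            super().__init__()
            self.model = nn.Conv2d(3, 8, kernel_size=3, padding=1)

        def forward(self, batch: dict[str, Tensor]) -> Tensor:
            # Training/validation entry point: consumes a batch entity (here, a dict).
            return self.model(batch["images"])

        def forward_for_tracing(self, image: Tensor) -> Tensor:
            # Export entry point: consumes a plain tensor, which is what tracing needs.
            return self.model(image)


    @contextlib.contextmanager
    def tracing_forward(model: TinyModel):
        """Temporarily route `forward` to `forward_for_tracing`, then restore everything."""
        orig_forward, was_training = model.forward, model.training
        model.eval()
        model.forward = model.forward_for_tracing  # type: ignore[method-assign]
        try:
            yield model
        finally:
            model.forward = orig_forward  # type: ignore[method-assign]
            model.train(was_training)


    if __name__ == "__main__":
        model = TinyModel()
        with tracing_forward(model):
            # The tracer exercises forward_for_tracing via the swapped-in forward.
            torch.onnx.export(model, torch.randn(1, 3, 32, 32), "tiny.onnx")

The base-class `export()` in this series performs the same swap but keeps the `OTXModel` itself as the object handed to the exporter (instead of `self.model`), which is why a `Trainer` is attached when none exists yet and why both the original `forward` and the original `_trainer` are restored in the `finally` block.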
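
The explain-mode branch of `forward_for_tracing` can be illustrated the same way with a hypothetical classifier: when `explain_mode` is enabled, the traced graph is expected to emit several named outputs, which is what the exporters' `output_names=["logits", "feature_vector", "saliency_map"]` refer to. The class name, layer sizes, and the toy saliency computation below are invented for the sketch; only the branching mirrors the patch.

    # Hypothetical sketch of the explain-mode tracing contract; names and shapes
    # are illustrative, not the OTX classifier implementation.
    from __future__ import annotations

    import torch
    from torch import Tensor, nn


    class TinyClassifier(nn.Module):
        def __init__(self, num_classes: int = 4) -> None:
            super().__init__()
            self.backbone = nn.Conv2d(3, 8, kernel_size=3, padding=1)
            self.head = nn.Linear(8, num_classes)
            self.explain_mode = False

        def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
            feats = self.backbone(image)
            pooled = feats.mean(dim=(2, 3))  # global average pool -> feature vector
            logits = self.head(pooled)
            if self.explain_mode:
                # Extra output sinks for XAI: these keys correspond to the exporter's
                # output_names=["logits", "feature_vector", "saliency_map"].
                saliency = feats.mean(dim=1, keepdim=True)  # toy per-pixel map
                return {"logits": logits, "feature_vector": pooled, "saliency_map": saliency}
            return logits


    if __name__ == "__main__":
        clf = TinyClassifier()
        image = torch.randn(1, 3, 32, 32)
        print(clf.forward_for_tracing(image).shape)            # plain tensor path
        clf.explain_mode = True
        print(sorted(clf.forward_for_tracing(image).keys()))   # explain path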