Skip to content

Commit

Permalink
Refactor model export part 2: Add a dedicated forward function for mo…
Browse files Browse the repository at this point in the history
…del export (#3317)

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Revert test_xai.py

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix mobilenetv3

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim authored Apr 18, 2024
1 parent 62e94bd commit 2e6d225
Show file tree
Hide file tree
Showing 24 changed files with 608 additions and 546 deletions.
38 changes: 16 additions & 22 deletions src/otx/algo/classification/mobilenet_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import TYPE_CHECKING, Any, Callable, Literal

import torch
from torch import nn
from torch import Tensor, nn

from otx.algo.classification.backbones import OTXMobileNetV3
from otx.algo.classification.classifier.base_classifier import ImageClassifier
Expand Down Expand Up @@ -158,14 +158,12 @@ def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassCls
feature_vector=outputs["feature_vector"],
)

def _reset_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.__orig_model_forward = self.model.forward
self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")

def _restore_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.model.forward = self.__orig_model_forward # type: ignore[method-assign]
return self.model(images=image, mode="tensor")


class MobileNetV3ForMultilabelCls(OTXMultilabelClsModel):
Expand Down Expand Up @@ -277,14 +275,12 @@ def forward_explain(self, inputs: MultilabelClsBatchDataEntity) -> MultilabelCls
feature_vector=outputs["feature_vector"],
)

def _reset_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.__orig_model_forward = self.model.forward
self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")

def _restore_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.model.forward = self.__orig_model_forward # type: ignore[method-assign]
return self.model(images=image, mode="tensor")


class MobileNetV3ForHLabelCls(OTXHlabelClsModel):
Expand Down Expand Up @@ -423,11 +419,9 @@ def forward_explain(self, inputs: HlabelClsBatchDataEntity) -> HlabelClsBatchPre
feature_vector=outputs["feature_vector"],
)

def _reset_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.__orig_model_forward = self.model.forward
self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")

def _restore_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.model.forward = self.__orig_model_forward # type: ignore[method-assign]
return self.model(images=image, mode="tensor")
6 changes: 5 additions & 1 deletion src/otx/algo/classification/otx_dino_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any

import torch
from torch import nn
from torch import Tensor, nn

from otx.algo.utils.mmconfig import read_mmconfig
from otx.core.data.entity.base import OTXBatchLossEntity
Expand Down Expand Up @@ -162,3 +162,7 @@ def _exporter(self) -> OTXModelExporter:
def _optimization_config(self) -> dict[str, Any]:
"""PTQ config for DinoV2Cls."""
return {"model_type": "transformer"}

def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
return self.model(image)
14 changes: 6 additions & 8 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import TYPE_CHECKING, Any, Callable, Literal

import torch
from torch import nn
from torch import Tensor, nn
from torchvision.models import get_model, get_model_weights

from otx.algo.explain.explain_algo import ReciproCAM
Expand Down Expand Up @@ -312,11 +312,9 @@ def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassCls
feature_vector=outputs["feature_vector"],
)

def _reset_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.__orig_model_forward = self.model.forward
self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001
def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
if self.explain_mode:
return self.model(images=image, mode="explain")

def _restore_model_forward(self) -> None:
# TODO(vinnamkim): This will be revisited by the export refactoring
self.model.forward = self.__orig_model_forward # type: ignore[method-assign]
return self.model(images=image, mode="tensor")
29 changes: 15 additions & 14 deletions src/otx/algo/detection/atss.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.atss",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.atss",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down
29 changes: 15 additions & 14 deletions src/otx/algo/detection/rtmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.rtmdet",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.rtmdet",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down
29 changes: 15 additions & 14 deletions src/otx/algo/detection/ssd.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,21 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.ssd_mobilenetv2",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard",
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
"""Callback on load checkpoint."""
Expand Down
58 changes: 30 additions & 28 deletions src/otx/algo/detection/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=True,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=True,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down Expand Up @@ -109,17 +110,18 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.detection.mmdeploy.yolox_tiny",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
58 changes: 30 additions & 28 deletions src/otx/algo/instance_segmentation/maskrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,20 +57,21 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)

def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model.") -> dict:
"""Load the previous OTX ckpt according to OTX2.0."""
Expand Down Expand Up @@ -109,17 +110,18 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.maskrcnn_swint",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="standard", # [TODO](@Eunwoo): need to revert it to fit_to_window after resolving
pad_value=0,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
29 changes: 15 additions & 14 deletions src/otx/algo/instance_segmentation/rtmdet_inst.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,17 +56,18 @@ def _exporter(self) -> OTXModelExporter:

mean, std = get_mean_std_from_data_processing(self.config)

return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
with self.export_model_forward_context():
return MMdeployExporter(
model_builder=self._create_model,
model_cfg=deepcopy(self.config),
deploy_cfg="otx.algo.instance_segmentation.mmdeploy.rtmdet_inst",
test_pipeline=self._make_fake_test_pipeline(),
task_level_export_parameters=self._export_parameters,
input_size=self.image_size,
mean=mean,
std=std,
resize_mode="fit_to_window_letterbox",
pad_value=114,
swap_rgb=False,
output_names=["feature_vector", "saliency_map"] if self.explain_mode else None,
)
Loading

0 comments on commit 2e6d225

Please sign in to comment.