diff --git a/src/otx/algo/classification/deit_tiny.py b/src/otx/algo/classification/deit_tiny.py
index 6efd8e8ec36..653a724a0f2 100644
--- a/src/otx/algo/classification/deit_tiny.py
+++ b/src/otx/algo/classification/deit_tiny.py
@@ -16,23 +16,23 @@
 from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
 from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
 from otx.core.model.classification import (
-    ExplainableOTXClsModel,
     MMPretrainHlabelClsModel,
     MMPretrainMulticlassClsModel,
     MMPretrainMultilabelClsModel,
 )
+from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
 from otx.core.schedulers import LRSchedulerListCallable
 from otx.core.types.label import HLabelInfo

 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
-    from mmpretrain.models import ImageClassifier
+    from mmpretrain.models.classifiers import ImageClassifier
     from mmpretrain.structures import DataSample

     from otx.core.metrics import MetricCallable


-class ExplainableDeit(ExplainableOTXClsModel):
+class ForwardExplainMixInForDeit(ExplainableMixInMMPretrainModel):
     """Deit model which can attach an XAI hook."""

     @torch.no_grad()
@@ -145,7 +145,7 @@ def _optimization_config(self) -> dict[str, Any]:
         return {"model_type": "transformer"}


-class DeitTinyForHLabelCls(ExplainableDeit, MMPretrainHlabelClsModel):
+class DeitTinyForHLabelCls(ForwardExplainMixInForDeit, MMPretrainHlabelClsModel):
     """DeitTiny Model for hierarchical label classification task."""

     def __init__(
@@ -172,7 +172,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


-class DeitTinyForMulticlassCls(ExplainableDeit, MMPretrainMulticlassClsModel):
+class DeitTinyForMulticlassCls(ForwardExplainMixInForDeit, MMPretrainMulticlassClsModel):
     """DeitTiny Model for multi-class classification task."""

     def __init__(
@@ -198,7 +198,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


-class DeitTinyForMultilabelCls(ExplainableDeit, MMPretrainMultilabelClsModel):
+class DeitTinyForMultilabelCls(ForwardExplainMixInForDeit, MMPretrainMultilabelClsModel):
     """DeitTiny Model for multi-label classification task."""

     def __init__(
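The renamed mix-in relies purely on Python's MRO: listed before the concrete model class, its explain-related methods are resolved first and become available to every subclass without the base model knowing about them. A minimal toy sketch of that ordering (these classes are hypothetical stand-ins, not OTX code):

```python
# Toy illustration of the mix-in ordering used throughout this PR.
class ToyModel:
    def forward(self) -> str:
        return "plain forward"


class ToyExplainMixIn:
    # Resolved first in the MRO, so subclasses gain this method
    # while still delegating the ordinary forward to the base model.
    def forward_explain(self) -> str:
        return f"explain wrapping: {self.forward()}"


class ToyExplainableModel(ToyExplainMixIn, ToyModel):
    """Mix-in listed first, exactly like ForwardExplainMixInForDeit above."""


print(ToyExplainableModel().forward_explain())  # explain wrapping: plain forward
```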
diff --git a/src/otx/algo/classification/efficientnet_b0.py b/src/otx/algo/classification/efficientnet_b0.py
index 5a5dd49af8b..c0c4ba212c4 100644
--- a/src/otx/algo/classification/efficientnet_b0.py
+++ b/src/otx/algo/classification/efficientnet_b0.py
@@ -15,6 +15,7 @@
     MMPretrainMulticlassClsModel,
     MMPretrainMultilabelClsModel,
 )
+from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
 from otx.core.schedulers import LRSchedulerListCallable
 from otx.core.types.label import HLabelInfo

@@ -24,7 +25,7 @@
     from otx.core.metrics import MetricCallable


-class EfficientNetB0ForHLabelCls(MMPretrainHlabelClsModel):
+class EfficientNetB0ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
     """EfficientNetB0 Model for hierarchical label classification task."""

     def __init__(
@@ -51,7 +52,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix)


-class EfficientNetB0ForMulticlassCls(MMPretrainMulticlassClsModel):
+class EfficientNetB0ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
     """EfficientNetB0 Model for multi-class classification task."""

     def __init__(
@@ -79,7 +80,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


-class EfficientNetB0ForMultilabelCls(MMPretrainMultilabelClsModel):
+class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
     """EfficientNetB0 Model for multi-label classification task."""

     def __init__(
diff --git a/src/otx/algo/classification/efficientnet_v2.py b/src/otx/algo/classification/efficientnet_v2.py
index 8b2bbf965d4..97632f42f44 100644
--- a/src/otx/algo/classification/efficientnet_v2.py
+++ b/src/otx/algo/classification/efficientnet_v2.py
@@ -15,6 +15,7 @@
     MMPretrainMulticlassClsModel,
     MMPretrainMultilabelClsModel,
 )
+from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
 from otx.core.schedulers import LRSchedulerListCallable
 from otx.core.types.label import HLabelInfo

@@ -24,7 +25,7 @@
     from otx.core.metrics import MetricCallable


-class EfficientNetV2ForHLabelCls(MMPretrainHlabelClsModel):
+class EfficientNetV2ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
     """EfficientNetV2 Model for hierarchical label classification task."""

     def __init__(
@@ -51,7 +52,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)


-class EfficientNetV2ForMulticlassCls(MMPretrainMulticlassClsModel):
+class EfficientNetV2ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
     """EfficientNetV2 Model for multi-class classification task."""

     def __init__(
@@ -79,7 +80,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix)


-class EfficientNetV2ForMultilabelCls(MMPretrainMultilabelClsModel):
+class EfficientNetV2ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
     """EfficientNetV2 Model for multi-label classification task."""

     def __init__(
diff --git a/src/otx/algo/classification/mobilenet_v3_large.py b/src/otx/algo/classification/mobilenet_v3_large.py
index 47253fcd73c..7ef27adfe23 100644
--- a/src/otx/algo/classification/mobilenet_v3_large.py
+++ b/src/otx/algo/classification/mobilenet_v3_large.py
@@ -15,6 +15,7 @@
     MMPretrainMulticlassClsModel,
     MMPretrainMultilabelClsModel,
 )
+from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
 from otx.core.schedulers import LRSchedulerListCallable
 from otx.core.types.label import HLabelInfo

@@ -24,7 +25,7 @@
     from otx.core.metrics import MetricCallable


-class MobileNetV3ForHLabelCls(MMPretrainHlabelClsModel):
+class MobileNetV3ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
     """MobileNetV3 Model for hierarchical label classification task."""

     def __init__(
@@ -58,7 +59,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)


-class MobileNetV3ForMulticlassCls(MMPretrainMulticlassClsModel):
+class MobileNetV3ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
     """MobileNetV3 Model for multi-class classification task."""

     def __init__(
@@ -93,7 +94,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
         return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix)


-class MobileNetV3ForMultilabelCls(MMPretrainMultilabelClsModel):
+class MobileNetV3ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
     """MobileNetV3 Model for multi-label classification task."""

     def __init__(
diff --git a/src/otx/algo/classification/torchvision_model.py b/src/otx/algo/classification/torchvision_model.py
index 2331cdf106a..796160f3cb5 100644
--- a/src/otx/algo/classification/torchvision_model.py
+++ b/src/otx/algo/classification/torchvision_model.py
@@ -9,11 +9,13 @@
 import torch
 from torch import nn
-from torchvision import tv_tensors
 from torchvision.models import get_model, get_model_weights

+from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
 from otx.core.data.entity.base import OTXBatchLossEntity
 from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
+from otx.core.exporter.base import OTXModelExporter
+from otx.core.exporter.native import OTXNativeModelExporter
 from otx.core.metrics.accuracy import MultiClassClsMetricCallable
 from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
 from otx.core.model.classification import OTXMulticlassClsModel
@@ -137,6 +139,12 @@ def __init__(
         self.softmax = nn.Softmax(dim=-1)
         self.loss = loss

+        self.explainer = ReciproCAMHook(
+            self._head_forward_fn,
+            num_classes=num_classes,
+            optimize_gap=True,
+        )
+
     def forward(
         self,
         images: torch.Tensor,
""" + model: TVModelWithLossComputation + def __init__( self, backbone: TVModelType, @@ -207,14 +248,17 @@ def _create_model(self) -> nn.Module: ) def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]: - if isinstance(inputs.images, list): - images = tv_tensors.wrap(torch.stack(inputs.images, dim=0), like=inputs.images[0]) + if self.training: + mode = "loss" + elif self.explain_mode: + mode = "explain" else: - images = inputs.images + mode = "predict" + return { - "images": images, + "images": inputs.stacked_images, "labels": torch.cat(inputs.labels, dim=0), - "mode": "loss" if self.training else "predict", + "mode": mode, } def _customize_outputs( @@ -230,23 +274,6 @@ def _customize_outputs( scores = torch.unbind(logits, 0) preds = logits.argmax(-1, keepdim=True).unbind(0) - if self.explain_mode: - if not isinstance(outputs, dict) or "saliency_map" not in outputs: - msg = "No saliency maps in the model output." - raise ValueError(msg) - - saliency_maps = outputs["saliency_map"].detach().cpu().numpy() - - return MulticlassClsBatchPredEntity( - batch_size=len(preds), - images=inputs.images, - imgs_info=inputs.imgs_info, - scores=scores, - labels=preds, - saliency_maps=list(saliency_maps), - feature_vectors=[], - ) - return MulticlassClsBatchPredEntity( batch_size=inputs.batch_size, images=inputs.images, @@ -255,6 +282,11 @@ def _customize_outputs( labels=preds, ) + @property + def _exporter(self) -> OTXModelExporter: + """Creates OTXModelExporter object that can export the model.""" + return OTXNativeModelExporter(**self._export_parameters) + @property def _export_parameters(self) -> dict[str, Any]: """Defines parameters required to export a particular model implementation.""" @@ -270,44 +302,28 @@ def _export_parameters(self) -> dict[str, Any]: parameters = super()._export_parameters parameters.update(export_params) - return parameters - - @staticmethod - def _forward_explain_image_classifier( - self: TVModelWithLossComputation, - images: torch.Tensor, - labels: torch.Tensor | None = None, # noqa: ARG004 - mode: str = "tensor", - ) -> dict: - """Forward func of the TVModelWithLossComputation instance.""" - x = self.backbone(images) - backbone_feat = x - - saliency_map = self.explain_fn(backbone_feat) - - if len(x.shape) == 4 and not self.use_layer_norm_2d: - x = x.view(x.size(0), -1) - feature_vector = x - if len(feature_vector.shape) == 1: - feature_vector = feature_vector.unsqueeze(0) + return parameters - logits = self.head(x) - if mode == "predict": - logits = self.softmax(logits) + def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity: + """Model forward explain function.""" + outputs = self.model(images=inputs.stacked_images, mode="explain") - return { - "logits": logits, - "feature_vector": feature_vector, - "saliency_map": saliency_map, - } + return MulticlassClsBatchPredEntity( + batch_size=len(outputs["preds"]), + images=inputs.images, + imgs_info=inputs.imgs_info, + labels=outputs["preds"], + scores=outputs["scores"], + saliency_map=outputs["saliency_map"], + feature_vector=outputs["feature_vector"], + ) - @torch.no_grad() - def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor: - """Performs model's neck and head forward. 
-    @torch.no_grad()
-    def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
-        """Performs model's neck and head forward. Can be redefined at the model's level."""
-        if (head := getattr(self.model, "head", None)) is None:
-            raise ValueError
+    def _reset_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
+        self.__orig_model_forward = self.model.forward
+        self.model.forward = self.model._forward_explain  # type: ignore[assignment] # noqa: SLF001

-        if len(x.shape) == 4 and not self.model.use_layer_norm_2d:
-            x = x.view(x.size(0), -1)
-        return head(x)
+    def _restore_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
+        self.model.forward = self.__orig_model_forward  # type: ignore[method-assign]
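With this change the XAI path of the torchvision wrapper is an ordinary forward call rather than a monkey-patched one. A hedged usage sketch, assuming an already-built `tv_model: TVModelWithLossComputation` instance (shapes are illustrative):

```python
# Sketch only: tv_model is an assumed TVModelWithLossComputation instance.
import torch

images = torch.randn(4, 3, 224, 224)  # hypothetical normalized batch

tv_model.eval()
with torch.no_grad():
    out = tv_model(images, mode="explain")

# The dict mirrors what _forward_explain returns:
print(out["scores"].shape)        # (4, num_classes) softmax scores
print(out["preds"].shape)         # (4,) argmax class ids
print(out["saliency_map"].shape)  # (4, num_classes, h, w) from ReciproCAMHook
```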
""" - def __init__(self, head_forward_fn: Callable | None = None, normalize: bool = True) -> None: + def __init__(self, head_forward_fn: HeadForwardFn | None = None, normalize: bool = True) -> None: self._head_forward_fn = head_forward_fn self.handle: RemovableHandle | None = None self._records: list[torch.Tensor] = [] @@ -102,17 +105,15 @@ def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> N tensor_list[i] = tensor.detach().cpu().numpy() @staticmethod - def _normalize_map(saliency_maps: torch.Tensor) -> torch.Tensor: + def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor: """Normalize saliency maps.""" - max_values, _ = torch.max(saliency_maps, -1) - min_values, _ = torch.min(saliency_maps, -1) - if len(saliency_maps.shape) == 2: - saliency_maps = 255 * (saliency_maps - min_values[:, None]) / (max_values - min_values + 1e-12)[:, None] + max_values, _ = torch.max(saliency_map, -1) + min_values, _ = torch.min(saliency_map, -1) + if len(saliency_map.shape) == 2: + saliency_map = 255 * (saliency_map - min_values[:, None]) / (max_values - min_values + 1e-12)[:, None] else: - saliency_maps = ( - 255 * (saliency_maps - min_values[:, :, None]) / (max_values - min_values + 1e-12)[:, :, None] - ) - return saliency_maps.to(torch.uint8) + saliency_map = 255 * (saliency_map - min_values[:, :, None]) / (max_values - min_values + 1e-12)[:, :, None] + return saliency_map.to(torch.uint8) class ActivationMapHook(BaseRecordingForwardHook): @@ -128,7 +129,7 @@ def create_and_register_hook( hook.handle = backbone.register_forward_hook(hook.recording_forward) return hook - def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor: + def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor: """Generate the saliency map by average feature maps then normalizing to (0, 255).""" if isinstance(feature_map, (list, tuple)): feature_map = feature_map[fpn_idx] @@ -151,7 +152,7 @@ class ReciproCAMHook(BaseRecordingForwardHook): def __init__( self, - head_forward_fn: Callable, + head_forward_fn: HeadForwardFn, num_classes: int, normalize: bool = True, optimize_gap: bool = False, @@ -164,7 +165,7 @@ def __init__( def create_and_register_hook( cls, backbone: torch.nn.Module, - head_forward_fn: Callable, + head_forward_fn: HeadForwardFn, num_classes: int, optimize_gap: bool, ) -> BaseRecordingForwardHook: @@ -177,7 +178,7 @@ def create_and_register_hook( hook.handle = backbone.register_forward_hook(hook.recording_forward) return hook - def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor: + def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor: """Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255). 
@@ -38,7 +41,7 @@ class BaseRecordingForwardHook:
         normalize (bool): Whether to normalize the resulting saliency maps.
     """

-    def __init__(self, head_forward_fn: Callable | None = None, normalize: bool = True) -> None:
+    def __init__(self, head_forward_fn: HeadForwardFn | None = None, normalize: bool = True) -> None:
         self._head_forward_fn = head_forward_fn
         self.handle: RemovableHandle | None = None
         self._records: list[torch.Tensor] = []
@@ -102,17 +105,15 @@ def _torch_to_numpy_from_list(self, tensor_list: list[torch.Tensor | None]) -> N
             tensor_list[i] = tensor.detach().cpu().numpy()

     @staticmethod
-    def _normalize_map(saliency_maps: torch.Tensor) -> torch.Tensor:
+    def _normalize_map(saliency_map: torch.Tensor) -> torch.Tensor:
         """Normalize saliency maps."""
-        max_values, _ = torch.max(saliency_maps, -1)
-        min_values, _ = torch.min(saliency_maps, -1)
-        if len(saliency_maps.shape) == 2:
-            saliency_maps = 255 * (saliency_maps - min_values[:, None]) / (max_values - min_values + 1e-12)[:, None]
+        max_values, _ = torch.max(saliency_map, -1)
+        min_values, _ = torch.min(saliency_map, -1)
+        if len(saliency_map.shape) == 2:
+            saliency_map = 255 * (saliency_map - min_values[:, None]) / (max_values - min_values + 1e-12)[:, None]
         else:
-            saliency_maps = (
-                255 * (saliency_maps - min_values[:, :, None]) / (max_values - min_values + 1e-12)[:, :, None]
-            )
-        return saliency_maps.to(torch.uint8)
+            saliency_map = 255 * (saliency_map - min_values[:, :, None]) / (max_values - min_values + 1e-12)[:, :, None]
+        return saliency_map.to(torch.uint8)
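`_normalize_map` min-max rescales each flattened map over its last axis into uint8 [0, 255]. A worked numeric example (standalone re-implementation of the 2D branch, for illustration only):

```python
import torch

saliency = torch.tensor([[0.2, 0.5, 1.0]])  # one flattened map
max_v, _ = torch.max(saliency, -1)
min_v, _ = torch.min(saliency, -1)
norm = 255 * (saliency - min_v[:, None]) / (max_v - min_v + 1e-12)[:, None]
print(norm.to(torch.uint8))  # tensor([[  0,  95, 255]], dtype=torch.uint8)
```

The `1e-12` guard only matters for constant maps, where it prevents a division by zero.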


 class ActivationMapHook(BaseRecordingForwardHook):
@@ -128,7 +129,7 @@ def create_and_register_hook(
         hook.handle = backbone.register_forward_hook(hook.recording_forward)
         return hook

-    def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor:
+    def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor:
         """Generate the saliency map by average feature maps then normalizing to (0, 255)."""
         if isinstance(feature_map, (list, tuple)):
             feature_map = feature_map[fpn_idx]
@@ -151,7 +152,7 @@ class ReciproCAMHook(BaseRecordingForwardHook):

     def __init__(
         self,
-        head_forward_fn: Callable,
+        head_forward_fn: HeadForwardFn,
         num_classes: int,
         normalize: bool = True,
         optimize_gap: bool = False,
@@ -164,7 +165,7 @@ def __init__(
     def create_and_register_hook(
         cls,
         backbone: torch.nn.Module,
-        head_forward_fn: Callable,
+        head_forward_fn: HeadForwardFn,
         num_classes: int,
         optimize_gap: bool,
     ) -> BaseRecordingForwardHook:
@@ -177,7 +178,7 @@ def create_and_register_hook(
         hook.handle = backbone.register_forward_hook(hook.recording_forward)
         return hook

-    def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int = -1) -> torch.Tensor:
+    def func(self, feature_map: FeatureMapType, fpn_idx: int = -1) -> torch.Tensor:
         """Generate the class-wise saliency maps using Recipro-CAM and then normalizing to (0, 255).

         Args:
@@ -193,17 +194,17 @@ def func(self, feature_map: torch.Tensor | Sequence[torch.Tensor], fpn_idx: int
             feature_map = feature_map[fpn_idx]

         batch_size, channel, h, w = feature_map.size()
-        saliency_maps = torch.empty(batch_size, self._num_classes, h, w)
+        saliency_map = torch.empty(batch_size, self._num_classes, h, w)
         for f in range(batch_size):
             mosaic_feature_map = self._get_mosaic_feature_map(feature_map[f], channel, h, w)
             mosaic_prediction = self._predict_from_feature_map(mosaic_feature_map)
-            saliency_maps[f] = mosaic_prediction.transpose(0, 1).reshape((self._num_classes, h, w))
+            saliency_map[f] = mosaic_prediction.transpose(0, 1).reshape((self._num_classes, h, w))

         if self._norm_saliency_maps:
-            saliency_maps = saliency_maps.reshape((batch_size, self._num_classes, h * w))
-            saliency_maps = self._normalize_map(saliency_maps)
+            saliency_map = saliency_map.reshape((batch_size, self._num_classes, h * w))
+            saliency_map = self._normalize_map(saliency_map)

-        return saliency_maps.reshape((batch_size, self._num_classes, h, w))
+        return saliency_map.reshape((batch_size, self._num_classes, h, w))

     def _get_mosaic_feature_map(self, feature_map: torch.Tensor, c: int, h: int, w: int) -> torch.Tensor:
         if self._optimize_gap:
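A hedged usage sketch for `ReciproCAMHook.func`: it masks the feature map cell by cell ("mosaic"), re-runs the head on each masked copy, and folds the per-cell class scores back into one map per class. The stand-in head below is a placeholder; the real hook would be given a model's neck/head forward, and `_predict_from_feature_map` internals are not shown in this diff:

```python
import torch
from otx.algo.hooks.recording_forward_hook import ReciproCAMHook

num_classes = 10


def head_fn(x: torch.Tensor) -> torch.Tensor:
    """Stand-in classifier head: any (N, ...) input -> (N, num_classes) logits."""
    return torch.randn(x.shape[0], num_classes)


hook = ReciproCAMHook(head_fn, num_classes=num_classes, optimize_gap=True)
feature_map = torch.randn(2, 64, 7, 7)  # B x C x H x W backbone output
saliency = hook.func(feature_map)
print(saliency.shape)  # torch.Size([2, 10, 7, 7]), uint8 after normalization
```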
@@ -239,7 +240,7 @@ class ViTReciproCAMHook(BaseRecordingForwardHook):

     def __init__(
         self,
-        head_forward_fn: Callable,
+        head_forward_fn: HeadForwardFn,
         num_classes: int,
         use_gaussian: bool = True,
         cls_token: bool = True,
@@ -254,7 +255,7 @@ def __init__(
     def create_and_register_hook(
         cls,
         target_layernorm: torch.nn.Module,
-        head_forward_fn: Callable,
+        head_forward_fn: HeadForwardFn,
         num_classes: int,
     ) -> BaseRecordingForwardHook:
         """Create this object and register it to the module forward hook."""
@@ -276,16 +277,16 @@ def func(self, feature_map: torch.Tensor, _: int = -1) -> torch.Tensor:
         """
         batch_size, token_number, _ = feature_map.size()
         h = w = int((token_number - 1) ** 0.5)
-        saliency_maps = torch.empty(batch_size, self._num_classes, h, w)
+        saliency_map = torch.empty(batch_size, self._num_classes, h, w)
         for i in range(batch_size):
             mosaic_feature_map = self._get_mosaic_feature_map(feature_map[i])
             mosaic_prediction = self._predict_from_feature_map(mosaic_feature_map)
-            saliency_maps[i] = mosaic_prediction.transpose(1, 0).reshape((self._num_classes, h, w))
+            saliency_map[i] = mosaic_prediction.transpose(1, 0).reshape((self._num_classes, h, w))

         if self._norm_saliency_maps:
-            saliency_maps = saliency_maps.reshape((batch_size, self._num_classes, h * w))
-            saliency_maps = self._normalize_map(saliency_maps)
-        return saliency_maps.reshape((batch_size, self._num_classes, h, w))
+            saliency_map = saliency_map.reshape((batch_size, self._num_classes, h * w))
+            saliency_map = self._normalize_map(saliency_map)
+        return saliency_map.reshape((batch_size, self._num_classes, h, w))

     def _get_mosaic_feature_map(self, feature_map: torch.Tensor) -> torch.Tensor:
         token_number, dim = feature_map.size()
@@ -346,13 +347,13 @@ def __init__(

     def func(
         self,
-        cls_scores: torch.Tensor | Sequence[torch.Tensor],
+        cls_scores: FeatureMapType,
         _: int = -1,
     ) -> torch.Tensor:
         """Generate the saliency map from raw classification head output, then normalizing to (0, 255).

         Args:
-            cls_scores (torch.Tensor | Sequence[torch.Tensor]): Classification scores from cls_head.
+            cls_scores (FeatureMapType): Classification scores from cls_head.

         Returns:
             torch.Tensor: Class-wise Saliency Maps. One saliency map per each class - [batch, class_id, H, W]
@@ -360,7 +361,7 @@ def func(
         middle_idx = len(cls_scores) // 2
         # Resize to the middle feature map
         batch_size, _, height, width = cls_scores[middle_idx].size()
-        saliency_maps = torch.empty(batch_size, self._num_classes, height, width)
+        saliency_map = torch.empty(batch_size, self._num_classes, height, width)
         for batch_idx in range(batch_size):
             cls_scores_anchorless = []
             for scale_idx, cls_scores_per_scale in enumerate(cls_scores):
@@ -377,18 +378,18 @@ def func(
                 for cls_scores_anchorless_per_level in cls_scores_anchorless
             ]

-            saliency_maps[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0)
+            saliency_map[batch_idx] = torch.cat(cls_scores_anchorless_resized, dim=0).mean(dim=0)

         # Don't use softmax for tiles in tiling detection, if the tile doesn't contain objects,
         # it would highlight one of the class maps as a background class
         if self.use_cls_softmax:
-            saliency_maps[0] = torch.stack([torch.softmax(t, dim=1) for t in saliency_maps[0]])
+            saliency_map[0] = torch.stack([torch.softmax(t, dim=1) for t in saliency_map[0]])

         if self._norm_saliency_maps:
-            saliency_maps = saliency_maps.reshape((batch_size, self._num_classes, -1))
-            saliency_maps = self._normalize_map(saliency_maps)
+            saliency_map = saliency_map.reshape((batch_size, self._num_classes, -1))
+            saliency_map = self._normalize_map(saliency_map)

-        return saliency_maps.reshape((batch_size, self._num_classes, height, width))
+        return saliency_map.reshape((batch_size, self._num_classes, height, width))


 class MaskRCNNRecordingForwardHook(BaseRecordingForwardHook):
@@ -436,19 +437,19 @@ def average_and_normalize(
         masks, scores, labels = (pred.masks, pred.scores, pred.labels)
         _, height, width = masks.shape

-        saliency_maps = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
+        saliency_map = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
         class_objects = [0 for _ in range(num_classes)]
         for confidence, class_ind, raw_mask in zip(scores, labels, masks):
             weighted_mask = raw_mask * confidence
-            saliency_maps[class_ind] += weighted_mask
+            saliency_map[class_ind] += weighted_mask
             class_objects[class_ind] += 1

         for class_ind in range(num_classes):
             # Normalize by number of objects of the certain class
-            saliency_maps[class_ind] /= max(class_objects[class_ind], 1)
+            saliency_map[class_ind] /= max(class_objects[class_ind], 1)

-        saliency_maps = saliency_maps.reshape((num_classes, -1))
-        saliency_maps = cls._normalize_map(saliency_maps)
+        saliency_map = saliency_map.reshape((num_classes, -1))
+        saliency_map = cls._normalize_map(saliency_map)

-        return saliency_maps.reshape(num_classes, height, width)
+        return saliency_map.reshape(num_classes, height, width)
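For the instance-segmentation hook, `average_and_normalize` confidence-weights each instance mask, accumulates per class, and averages. A hedged sketch, assuming the classmethod signature `(pred, num_classes)` and using a `SimpleNamespace` stand-in for the InstanceData-like prediction:

```python
import torch
from types import SimpleNamespace
from otx.algo.hooks.recording_forward_hook import MaskRCNNRecordingForwardHook

pred = SimpleNamespace(  # hypothetical prediction with 3 instances
    masks=torch.rand(3, 32, 32) > 0.5,
    scores=torch.tensor([0.9, 0.7, 0.4]),
    labels=torch.tensor([0, 0, 1]),
)
sal = MaskRCNNRecordingForwardHook.average_and_normalize(pred, num_classes=2)
print(sal.shape)  # torch.Size([2, 32, 32]), uint8 in [0, 255]
```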
diff --git a/src/otx/algo/utils/xai_utils.py b/src/otx/algo/utils/xai_utils.py
index 47bfe16b6b4..9718e498e57 100644
--- a/src/otx/algo/utils/xai_utils.py
+++ b/src/otx/algo/utils/xai_utils.py
@@ -2,6 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 """Utils used for XAI."""

+# TODO(gzalessk): The typing in this file is too weak or wrong. It should be fixed.
+# For example, `pred_labels: list | None` gives no type for the objects contained in the list.
+# On the other hand, process_saliency_maps should not produce a list of dictionaries
+# (`list[dict[str, np.ndarray | torch.Tensor]]`).
+# This is because the output will be assigned to OTXBatchPredEntity.saliency_map,
+# which has `list[np.ndarray | torch.Tensor]` typing, so there is a mismatch.
+
 from __future__ import annotations

 from pathlib import Path
@@ -9,10 +16,17 @@

 import cv2
 import numpy as np
+import torch
 from datumaro import Image

 from otx.core.config.explain import ExplainConfig
-from otx.core.data.entity.base import OTXBatchPredEntity
+from otx.core.data.entity.classification import (
+    HlabelClsBatchPredEntity,
+    MulticlassClsBatchPredEntity,
+    MultilabelClsBatchPredEntity,
+)
+from otx.core.data.entity.detection import DetBatchPredEntity
+from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
 from otx.core.types.explain import TargetExplainGroup

 if TYPE_CHECKING:
@@ -20,40 +34,51 @@

     from otx.core.data.module import OTXDataModule

+ProcessedSaliencyMaps = list[dict[str, np.ndarray | torch.Tensor]]
+OTXBatchPredEntitiesSupportXAI = (
+    MulticlassClsBatchPredEntity
+    | MultilabelClsBatchPredEntity
+    | HlabelClsBatchPredEntity
+    | DetBatchPredEntity
+    | InstanceSegBatchPredEntity
+)
+

 def process_saliency_maps_in_pred_entity(
-    predict_result: list[OTXBatchPredEntity],
+    predict_result: list[OTXBatchPredEntitiesSupportXAI],
     explain_config: ExplainConfig,
-) -> list[OTXBatchPredEntity]:
+) -> list[OTXBatchPredEntitiesSupportXAI]:
     """Process saliency maps in PredEntity."""

-    def _process(predict_result_per_batch: OTXBatchPredEntity) -> OTXBatchPredEntity:
-        saliency_maps = predict_result_per_batch.saliency_maps
+    def _process(predict_result_per_batch: OTXBatchPredEntitiesSupportXAI) -> OTXBatchPredEntitiesSupportXAI:
+        saliency_map: list[np.ndarray] = [
+            saliency_map.cpu().numpy() if isinstance(saliency_map, torch.Tensor) else saliency_map
+            for saliency_map in predict_result_per_batch.saliency_map
+        ]
         imgs_info = predict_result_per_batch.imgs_info
         ori_img_shapes = [img_info.ori_shape for img_info in imgs_info]

-        if pred_labels := getattr(predict_result_per_batch, "labels", None):
-            pred_labels = [pred.tolist() for pred in pred_labels]
+        pred_labels = [pred.tolist() for pred in predict_result_per_batch.labels]

-        processed_saliency_maps = process_saliency_maps(saliency_maps, explain_config, pred_labels, ori_img_shapes)
+        processed_saliency_maps = process_saliency_maps(saliency_map, explain_config, pred_labels, ori_img_shapes)

-        return predict_result_per_batch.wrap(saliency_maps=processed_saliency_maps)
+        return predict_result_per_batch.wrap(saliency_map=processed_saliency_maps)

     return [_process(predict_result_per_batch) for predict_result_per_batch in predict_result]


 def process_saliency_maps(
-    saliency_maps: list,
+    saliency_map: list[np.ndarray],
     explain_config: ExplainConfig,
     pred_labels: list | None,
     ori_img_shapes: list,
-) -> list[dict[Any, Any]]:
+) -> ProcessedSaliencyMaps:
     """Perform saliency map conversion to dict and post-processing."""
     if explain_config.target_explain_group == TargetExplainGroup.ALL:
-        processed_saliency_maps = convert_maps_to_dict_all(saliency_maps)
+        processed_saliency_maps = convert_maps_to_dict_all(saliency_map)
     elif explain_config.target_explain_group == TargetExplainGroup.PREDICTIONS:
-        processed_saliency_maps = convert_maps_to_dict_predictions(saliency_maps, pred_labels)
+        processed_saliency_maps = convert_maps_to_dict_predictions(saliency_map, pred_labels)
    elif explain_config.target_explain_group == TargetExplainGroup.IMAGE:
-        processed_saliency_maps = convert_maps_to_dict_image(saliency_maps)
+        processed_saliency_maps = convert_maps_to_dict_image(saliency_map)
     else:
         msg = f"Target explain group {explain_config.target_explain_group} is not supported."
         raise ValueError(msg)
@@ -67,40 +92,43 @@ def process_saliency_maps(
     return processed_saliency_maps


-def convert_maps_to_dict_all(saliency_maps: np.array) -> list[dict[Any, np.array]]:
+def convert_maps_to_dict_all(saliency_map: list[np.ndarray]) -> list[dict[Any, np.array]]:
     """Convert saliency maps to dict for TargetExplainGroup.ALL."""
-    if saliency_maps[0].ndim != 3:
-        msg = "Shape mismatch."
-        raise ValueError(msg)
-
     processed_saliency_maps = []
-    for maps_per_image in saliency_maps:
+    for maps_per_image in saliency_map:
+        if maps_per_image.ndim != 3:
+            msg = "Shape mismatch."
+            raise ValueError(msg)
+
         explain_target_to_sal_map = dict(enumerate(maps_per_image))
         processed_saliency_maps.append(explain_target_to_sal_map)
     return processed_saliency_maps


-def convert_maps_to_dict_predictions(saliency_maps: np.array, pred_labels: list | None) -> list[dict[Any, np.array]]:
+def convert_maps_to_dict_predictions(
+    saliency_map: list[np.ndarray],
+    pred_labels: list | None,
+) -> list[dict[Any, np.array]]:
     """Convert saliency maps to dict for TargetExplainGroup.PREDICTIONS."""
-    if saliency_maps[0].ndim != 3:
+    if saliency_map[0].ndim != 3:
         msg = "Shape mismatch."
         raise ValueError(msg)
     if not pred_labels:
         return []

     processed_saliency_maps = []
-    for i, maps_per_image in enumerate(saliency_maps):
+    for i, maps_per_image in enumerate(saliency_map):
         explain_target_to_sal_map = {label: maps_per_image[label] for label in pred_labels[i] if pred_labels[i]}
         processed_saliency_maps.append(explain_target_to_sal_map)
     return processed_saliency_maps


-def convert_maps_to_dict_image(saliency_maps: np.array) -> list[dict[Any, np.array]]:
+def convert_maps_to_dict_image(saliency_map: list[np.ndarray]) -> list[dict[Any, np.array]]:
     """Convert saliency maps to dict for TargetExplainGroup.IMAGE."""
-    if saliency_maps[0].ndim != 2:
+    if saliency_map[0].ndim != 2:
         msg = "Shape mismatch."
         raise ValueError(msg)

-    return [{"map_per_image": map_per_image} for map_per_image in saliency_maps]
+    return [{"map_per_image": map_per_image} for map_per_image in saliency_map]


 def postprocess(saliency_map: np.ndarray, output_size: tuple[int, int] | None) -> np.ndarray:
@@ -116,24 +144,24 @@ def postprocess(saliency_map: np.ndarray, output_size: tuple[int, int] | None) -

 def dump_saliency_maps(
-    predict_result: list[OTXBatchPredEntity],
+    predict_result: list[OTXBatchPredEntitiesSupportXAI],
     explain_config: ExplainConfig,
     datamodule: EVAL_DATALOADERS | OTXDataModule,
     output_dir: Path,
     weight: float = 0.3,
 ) -> None:
     """Dumps saliency maps (raw and with overlay)."""
-    output_dir = output_dir / "saliency_maps"
+    output_dir = output_dir / "saliency_map"
     output_dir.mkdir(parents=True, exist_ok=True)
     for predict_result_per_batch in predict_result:
-        saliency_maps = predict_result_per_batch.saliency_maps
+        saliency_map = predict_result_per_batch.saliency_map
         imgs_info = predict_result_per_batch.imgs_info
-        for pred_index in range(len(saliency_maps)):
+        for pred_index in range(len(saliency_map)):
             img_id = imgs_info[pred_index].img_idx
             img_data, image_save_name = _get_image_data_name(datamodule, img_id)
-            for class_id, s_map in saliency_maps[pred_index].items():
+            for class_id, s_map in saliency_map[pred_index].items():
                 file_name_map = Path(image_save_name + "_class_" + str(class_id) + "_saliency_map.png")
                 save_path_map = output_dir / file_name_map
                 cv2.imwrite(str(save_path_map), s_map)
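A short sketch of the dict conversion these helpers perform: under `TargetExplainGroup.ALL`, each per-image `(num_classes, H, W)` array becomes a `{class_id: map}` dict (shapes illustrative):

```python
import numpy as np
from otx.algo.utils.xai_utils import convert_maps_to_dict_all

batch_maps = [np.zeros((3, 7, 7), dtype=np.uint8)]  # one image, 3 classes
per_image_dicts = convert_maps_to_dict_all(batch_maps)
print(sorted(per_image_dicts[0].keys()))  # [0, 1, 2] - one entry per class
print(per_image_dicts[0][0].shape)        # (7, 7)
```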
diff --git a/src/otx/core/data/entity/base.py b/src/otx/core/data/entity/base.py
index 5373d249d64..a1603d89134 100644
--- a/src/otx/core/data/entity/base.py
+++ b/src/otx/core/data/entity/base.py
@@ -669,25 +669,29 @@ class OTXBatchPredEntity(OTXBatchDataEntity):

     Attributes:
         scores: List of probability scores representing model predictions.
-        saliency_maps: List of saliency maps used to explain model predictions.
+        saliency_map: List of saliency maps used to explain model predictions.
             This field is optional and will be an empty list for non-XAI pipelines.
-        feature_vectors: List of intermediate feature vectors used for model predictions.
+        feature_vector: List of intermediate feature vectors used for model predictions.
             This field is optional and will be an empty list for non-XAI pipelines.
     """

     scores: list[np.ndarray] | list[Tensor]

     # (Optional) XAI-related outputs
-    saliency_maps: list[np.ndarray] | list[Tensor] = field(default_factory=list)
-    feature_vectors: list[np.ndarray] | list[Tensor] = field(default_factory=list)
+    # TODO(vinnamkim): These are actually plural, but their names are not.
+    # This is because ModelAPI requires the OV IR to produce `saliency_map`
+    # and `feature_vector` (singular) named outputs.
+    # It should be fixed later.
+    saliency_map: list[np.ndarray] | list[Tensor] = field(default_factory=list)
+    feature_vector: list[np.ndarray] | list[Tensor] = field(default_factory=list)

     @property
     def has_xai_outputs(self) -> bool:
         """If the XAI related fields are fulfilled, return True."""
         # NOTE: Don't know why but some of test cases in tests/integration/api/test_xai.py
-        # produce `len(self.saliency_maps) > 0` and `len(self.feature_vectors) == 0`
-        # return len(self.saliency_maps) > 0 and len(self.feature_vectors) > 0
-        return len(self.saliency_maps) > 0
+        # produce `len(self.saliency_map) > 0` and `len(self.feature_vector) == 0`
+        # return len(self.saliency_map) > 0 and len(self.feature_vector) > 0
+        return len(self.saliency_map) > 0


 class OTXBatchLossEntity(Dict[str, Tensor]):
diff --git a/src/otx/core/exporter/base.py b/src/otx/core/exporter/base.py
index b29946c78b4..6982afbe929 100644
--- a/src/otx/core/exporter/base.py
+++ b/src/otx/core/exporter/base.py
@@ -241,16 +241,35 @@ def _extend_model_metadata(self, metadata: dict[tuple[str, str], str]) -> dict[t
         return extra_data

     def _postprocess_openvino_model(self, exported_model: openvino.Model) -> openvino.Model:
-        if self.output_names is not None:
-            if len(self.output_names) != len(exported_model.outputs):
-                msg = "The number of outputs in the exported model doesn't match with exporter parameters"
-                raise RuntimeError(msg)
-            for i, name in enumerate(self.output_names):
-                exported_model.outputs[i].tensor.set_names({name})
-        elif len(exported_model.outputs) == 1 and len(exported_model.outputs[0].get_names()) == 0:
+        if len(exported_model.outputs) == 1 and len(exported_model.outputs[0].get_names()) == 0:
             # workaround for OVC's bug: single output doesn't have a name in OV model
             exported_model.outputs[0].tensor.set_names({"output1"})

+        if self.output_names is not None:
+            traced_outputs = [(output.get_names(), output) for output in exported_model.outputs]
+
+            for output_name in self.output_names:
+                found = False
+                for name, output in traced_outputs:
+                    # TODO(vinnamkim): This is because `name` in `traced_outputs` is a list of sets, such as
+                    # [{'logits', '1555'}, {'1556', 'preds'}, {'1557', 'scores'},
+                    # {'saliency_map', '1551'}, {'feature_vector', '1554', 'input.1767'}]
+                    # This ugly format of `name` comes from `openvino.convert_model`.
+                    # Find a cleaner way for this in the future.
+                    if output_name in name:
+                        found = True
+                        # NOTE: This is because without this renaming such as
+                        # `{'saliency_map', '1551'}` => `{'saliency_map'}`
+                        # ModelAPI cannot produce the outputs correctly.
+                        output.tensor.set_names({output_name})
+
+                if not found:
+                    msg = (
+                        "Given output name to export is not in the traced_outputs, "
+                        f"{output_name} not in {traced_outputs}"
+                    )
+                    raise RuntimeError(msg)
+
         if self.metadata is not None:
             export_metadata = self._extend_model_metadata(self.metadata)
             exported_model = self._embed_openvino_ir_metadata(exported_model, export_metadata)
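The renaming logic above only uses `get_names()`/`set_names()` on the exported model's output tensors. A hedged sketch of the same step in isolation, assuming `ov_model` came from `openvino.convert_model` on a module with multiple named outputs:

```python
# Sketch only: ov_model is an assumed openvino.Model with auto-generated names.
desired = ["logits", "feature_vector", "saliency_map"]

for output_name in desired:
    for output in ov_model.outputs:
        if output_name in output.get_names():
            # Collapse e.g. {'saliency_map', '1551'} to {'saliency_map'} so
            # downstream consumers can look the tensor up by a single name.
            output.tensor.set_names({output_name})
```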
diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py
index 84f8cd15de8..d6792052bde 100644
--- a/src/otx/core/model/base.py
+++ b/src/otx/core/model/base.py
@@ -80,6 +80,9 @@ class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEnti

     Args:
         num_classes: Number of classes this model can predict.
+
+    Attributes:
+        explain_mode: If true, `self.predict_step()` will produce an XAI output as well
     """

     _OPTIMIZED_MODEL_BASE_NAME: str = "optimized_model"
@@ -97,7 +100,7 @@ def __init__(
         self._label_info = LabelInfo.from_num_classes(num_classes) if num_classes > 0 else NullLabelInfo()
         self.classification_layers: dict[str, dict[str, Any]] = {}
         self.model = self._create_model()
-        self.original_model_forward = None
+        self._explain_mode = False

         self.optimizer_callable = ensure_callable(optimizer)
         self.scheduler_callable = ensure_callable(scheduler)
@@ -474,9 +477,11 @@ def get_explain_fn(self) -> Callable:
         raise NotImplementedError

     def _reset_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
         pass

     def _restore_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
         pass

     def forward_tiles(
@@ -606,7 +611,7 @@ def _export_parameters(self) -> dict[str, Any]:
         Returns:
             dict[str, Any]: parameters of exporter.
         """
-        parameters = {}
+        parameters: dict[str, Any] = {}
         all_labels = ""
         all_label_ids = ""
         for lbl in self.label_info.label_names:
@@ -622,6 +627,9 @@ def _export_parameters(self) -> dict[str, Any]:
             ("model_info", "label_info"): self.label_info.to_json(),
         }

+        if self.explain_mode:
+            parameters["output_names"] = ["logits", "feature_vector", "saliency_map"]
+
         return parameters

     def _reset_prediction_layer(self, num_classes: int) -> None:
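This is the link between `explain_mode` and the exporter's output renaming shown earlier: the model tells the exporter which tensor names to expose. A hedged sketch of that data flow, assuming `model` is a concrete OTXModel subclass instance (`_export_parameters` is private and shown only for illustration):

```python
# Sketch only: `model` is an assumed OTXModel subclass instance.
model.explain_mode = True
params = model._export_parameters  # noqa: SLF001 - private, for illustration
# With explain_mode on, the OV IR outputs get the names ModelAPI expects:
assert params["output_names"] == ["logits", "feature_vector", "saliency_map"]
```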
diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py
index 7a1fd57c64e..e9f3c36e70b 100644
--- a/src/otx/core/model/classification.py
+++ b/src/otx/core/model/classification.py
@@ -6,18 +6,13 @@
 from __future__ import annotations

 import json
-import types
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any

 import numpy as np
 import torch
 from torchmetrics import Accuracy

-from otx.core.data.entity.base import (
-    OTXBatchLossEntity,
-    T_OTXBatchDataEntity,
-    T_OTXBatchPredEntity,
-)
+from otx.core.data.entity.base import OTXBatchLossEntity
 from otx.core.data.entity.classification import (
     HlabelClsBatchDataEntity,
     HlabelClsBatchPredEntity,
@@ -43,9 +38,7 @@

 if TYPE_CHECKING:
     from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
-    from mmpretrain.models import ImageClassifier
     from mmpretrain.models.utils import ClsDataPreprocessor
-    from mmpretrain.structures import DataSample
     from omegaconf import DictConfig
     from openvino.model_api.models.utils import ClassificationResult
     from torch import nn
@@ -53,131 +46,8 @@

     from otx.core.metrics import MetricCallable


-class ExplainableOTXClsModel(
-    OTXModel[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity],
-):
-    """OTX classification model which can attach a XAI hook."""
-
-    @property
-    def has_gap(self) -> bool:
-        """Defines if GAP is used right after backbone. Can be redefined at the model's level."""
-        return True
-
-    @property
-    def _export_parameters(self) -> dict[str, Any]:
-        """Defines parameters required to export a particular model implementation."""
-        export_params = super()._export_parameters
-        export_params["output_names"] = ["logits", "feature_vector", "saliency_map"] if self.explain_mode else None
-        return export_params
-
-    @torch.no_grad()
-    def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
-        """Performs model's neck and head forward. Can be redefined at the model's level."""
-        if (neck := getattr(self.model, "neck", None)) is None:
-            raise ValueError
-        if (head := getattr(self.model, "head", None)) is None:
-            raise ValueError
-
-        output = neck(x)
-        return head([output])
-
-    def forward_explain(self, inputs: T_OTXBatchDataEntity) -> T_OTXBatchPredEntity:
-        """Model forward function."""
-        from otx.algo.hooks.recording_forward_hook import feature_vector_fn
-
-        self.model.feature_vector_fn = feature_vector_fn
-        self.model.explain_fn = self.get_explain_fn()
-
-        # If customize_inputs is overridden
-        outputs = (
-            self._forward_explain_image_classifier(self.model, **self._customize_inputs(inputs))
-            if self._customize_inputs != ExplainableOTXClsModel._customize_inputs
-            else self._forward_explain_image_classifier(self.model, inputs)
-        )
-
-        return (
-            self._customize_outputs(outputs, inputs)
-            if self._customize_outputs != ExplainableOTXClsModel._customize_outputs
-            else outputs["predictions"]
-        )
-
-    @staticmethod
-    def _forward_explain_image_classifier(
-        self: ImageClassifier,
-        inputs: torch.Tensor,
-        data_samples: list[DataSample] | None = None,
-        mode: str = "tensor",
-    ) -> dict[str, torch.Tensor]:
-        """Forward func of the ImageClassifier instance, which located in ExplainableOTXClsModel().model."""
-        x = self.backbone(inputs)
-        backbone_feat = x
-
-        feature_vector = self.feature_vector_fn(backbone_feat)
-        saliency_map = self.explain_fn(backbone_feat)
-
-        if self.with_neck:
-            x = self.neck(x)
-
-        if mode == "tensor":
-            logits = self.head(x) if self.with_head else x
-        elif mode == "predict":
-            logits = self.head.predict(x, data_samples)
-        else:
-            msg = f'Invalid mode "{mode}".'
-            raise RuntimeError(msg)
-
-        return {
-            "logits": logits,
-            "feature_vector": feature_vector,
-            "saliency_map": saliency_map,
-        }
-
-    def get_explain_fn(self) -> Callable:
-        """Returns explain function."""
-        from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
-
-        explainer = ReciproCAMHook(
-            self.head_forward_fn,
-            num_classes=self.num_classes,
-            optimize_gap=self.has_gap,
-        )
-        return explainer.func
-
-    def _reset_model_forward(self) -> None:
-        from otx.algo.hooks.recording_forward_hook import feature_vector_fn
-
-        if not self.explain_mode:
-            return
-
-        self.model.feature_vector_fn = feature_vector_fn
-        self.model.explain_fn = self.get_explain_fn()
-        forward_with_explain = self._forward_explain_image_classifier
-
-        self.original_model_forward = self.model.forward
-
-        func_type = types.MethodType
-        self.model.forward = func_type(forward_with_explain, self.model)
-
-    def _restore_model_forward(self) -> None:
-        if not self.explain_mode:
-            return
-
-        if not self.original_model_forward:
-            msg = "Original model forward was not saved."
-            raise RuntimeError(msg)
-
-        func_type = types.MethodType
-        self.model.forward = func_type(self.original_model_forward, self.model)
-        self.original_model_forward = None
-
-    @property
-    def _exporter(self) -> OTXModelExporter:
-        """Creates OTXModelExporter object that can export the model."""
-        return OTXNativeModelExporter(**self._export_parameters)
-
-
 class OTXMulticlassClsModel(
-    ExplainableOTXClsModel[
+    OTXModel[
         MulticlassClsBatchDataEntity,
         MulticlassClsBatchPredEntity,
         T_OTXTileBatchDataEntity,
@@ -332,8 +202,8 @@ def _customize_outputs(
                 msg = "No saliency maps in the model output."
                 raise ValueError(msg)

-            feature_vectors = outputs["feature_vector"].detach().cpu().numpy()
-            saliency_maps = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach()
+            saliency_map = outputs["saliency_map"].detach()

             return MulticlassClsBatchPredEntity(
                 batch_size=len(predictions),
@@ -341,8 +211,8 @@ def _customize_outputs(
                 imgs_info=inputs.imgs_info,
                 scores=scores,
                 labels=labels,
-                feature_vectors=list(feature_vectors),
-                saliency_maps=list(saliency_maps),
+                feature_vector=list(feature_vector),
+                saliency_map=list(saliency_map),
             )

         return MulticlassClsBatchPredEntity(
@@ -353,6 +223,11 @@ def _customize_outputs(
             labels=labels,
         )

+    @property
+    def _exporter(self) -> OTXModelExporter:
+        """Creates OTXModelExporter object that can export the model."""
+        return OTXNativeModelExporter(**self._export_parameters)
+
     @property
     def _export_parameters(self) -> dict[str, Any]:
         """Defines parameters required to export a particular model implementation."""
@@ -373,7 +248,7 @@ def _export_parameters(self) -> dict[str, Any]:


 class OTXMultilabelClsModel(
-    ExplainableOTXClsModel[
+    OTXModel[
         MultilabelClsBatchDataEntity,
         MultilabelClsBatchPredEntity,
         T_OTXTileBatchDataEntity,
@@ -529,8 +404,8 @@ def _customize_outputs(
             msg = "No saliency maps in the model output."
             raise ValueError(msg)

-            feature_vectors = outputs["feature_vector"].detach().cpu().numpy()
-            saliency_maps = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach()
+            saliency_map = outputs["saliency_map"].detach()

             return MultilabelClsBatchPredEntity(
                 batch_size=len(predictions),
@@ -538,8 +413,8 @@ def _customize_outputs(
                 imgs_info=inputs.imgs_info,
                 scores=scores,
                 labels=labels,
-                feature_vectors=list(feature_vectors),
-                saliency_maps=list(saliency_maps),
+                feature_vector=list(feature_vector),
+                saliency_map=list(saliency_map),
             )

         return MultilabelClsBatchPredEntity(
@@ -550,6 +425,11 @@ def _customize_outputs(
             labels=labels,
         )

+    @property
+    def _exporter(self) -> OTXModelExporter:
+        """Creates OTXModelExporter object that can export the model."""
+        return OTXNativeModelExporter(**self._export_parameters)
+
     @property
     def _export_parameters(self) -> dict[str, Any]:
         """Defines parameters required to export a particular model implementation."""
@@ -566,7 +446,7 @@ def _export_parameters(self) -> dict[str, Any]:


 class OTXHlabelClsModel(
-    ExplainableOTXClsModel[
+    OTXModel[
         HlabelClsBatchDataEntity,
         HlabelClsBatchPredEntity,
         T_OTXTileBatchDataEntity,
@@ -749,8 +629,8 @@ def _customize_outputs(
             msg = "No saliency maps in the model output."
             raise ValueError(msg)

-            feature_vectors = outputs["feature_vector"].detach().cpu().numpy()
-            saliency_maps = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach()
+            saliency_map = outputs["saliency_map"].detach()

             return HlabelClsBatchPredEntity(
                 batch_size=len(outputs),
@@ -758,8 +638,8 @@ def _customize_outputs(
                 imgs_info=inputs.imgs_info,
                 scores=scores,
                 labels=labels,
-                feature_vectors=list(feature_vectors),
-                saliency_maps=list(saliency_maps),
+                feature_vector=list(feature_vector),
+                saliency_map=list(saliency_map),
             )

         return HlabelClsBatchPredEntity(
@@ -770,6 +650,11 @@ def _customize_outputs(
             labels=labels,
         )

+    @property
+    def _exporter(self) -> OTXModelExporter:
+        """Creates OTXModelExporter object that can export the model."""
+        return OTXNativeModelExporter(**self._export_parameters)
+
     @property
     def _export_parameters(self) -> dict[str, Any]:
         """Defines parameters required to export a particular model implementation."""
@@ -835,8 +720,8 @@ def _customize_outputs(
             imgs_info=inputs.imgs_info,
             scores=pred_scores,
             labels=pred_labels,
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return MulticlassClsBatchPredEntity(
@@ -909,8 +794,8 @@ def _customize_outputs(
             imgs_info=inputs.imgs_info,
             scores=pred_scores,
             labels=[],
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return MultilabelClsBatchPredEntity(
@@ -1007,8 +892,8 @@ def _customize_outputs(
             imgs_info=inputs.imgs_info,
             scores=all_pred_scores,
             labels=all_pred_labels,
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return HlabelClsBatchPredEntity(
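Note the classification models now keep the XAI outputs as detached torch tensors rather than eagerly converting to numpy; `list(tensor)` splits the batch dimension into per-sample tensors, and the `.cpu().numpy()` conversion is deferred to the post-processing in `xai_utils`. A minimal illustration of that unbinding:

```python
import torch

saliency_map = torch.arange(24).reshape(2, 3, 2, 2)  # B x num_classes x H x W
per_sample = list(saliency_map)                      # iterates over dim 0
print(len(per_sample), per_sample[0].shape)          # 2 torch.Size([3, 2, 2])
```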
diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py
index e0e2fbefd8c..4639c4ef026 100644
--- a/src/otx/core/model/detection.py
+++ b/src/otx/core/model/detection.py
@@ -190,9 +190,9 @@ def forward_explain(
         inputs: DetBatchDataEntity,
     ) -> DetBatchPredEntity:
         """Model forward function."""
-        from otx.algo.hooks.recording_forward_hook import feature_vector_fn
+        from otx.algo.hooks.recording_forward_hook import get_feature_vector

-        self.model.feature_vector_fn = feature_vector_fn
+        self.model.feature_vector_fn = get_feature_vector
         self.model.explain_fn = self.get_explain_fn()

         # If customize_inputs is overridden
@@ -460,8 +460,8 @@ def _customize_outputs(
             msg = "No saliency maps in the model output."
             raise ValueError(msg)

-            saliency_maps = outputs["saliency_map"].detach().cpu().numpy()
-            feature_vectors = outputs["feature_vector"].detach().cpu().numpy()
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()

             return DetBatchPredEntity(
                 batch_size=len(predictions),
@@ -470,8 +470,8 @@ def _customize_outputs(
                 scores=scores,
                 bboxes=bboxes,
                 labels=labels,
-                saliency_maps=saliency_maps,
-                feature_vectors=feature_vectors,
+                saliency_map=saliency_map,
+                feature_vector=feature_vector,
             )

         return DetBatchPredEntity(
@@ -611,8 +611,8 @@ def _customize_outputs(
             scores=scores,
             bboxes=bboxes,
             labels=labels,
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return DetBatchPredEntity(
diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py
index caa45852ae4..572ba044af8 100644
--- a/src/otx/core/model/instance_segmentation.py
+++ b/src/otx/core/model/instance_segmentation.py
@@ -240,9 +240,9 @@ def forward_explain(
         inputs: InstanceSegBatchDataEntity,
     ) -> InstanceSegBatchPredEntity:
         """Model forward function."""
-        from otx.algo.hooks.recording_forward_hook import feature_vector_fn
+        from otx.algo.hooks.recording_forward_hook import get_feature_vector

-        self.model.feature_vector_fn = feature_vector_fn
+        self.model.feature_vector_fn = get_feature_vector
         self.model.explain_fn = self.get_explain_fn()

         # If customize_inputs is overridden
@@ -504,8 +504,8 @@ def _customize_outputs(
             msg = "No saliency maps in the model output."
             raise ValueError(msg)

-            saliency_maps = outputs["saliency_map"].detach().cpu().numpy()
-            feature_vectors = outputs["feature_vector"].detach().cpu().numpy()
+            saliency_map = outputs["saliency_map"].detach().cpu().numpy()
+            feature_vector = outputs["feature_vector"].detach().cpu().numpy()

             return InstanceSegBatchPredEntity(
                 batch_size=len(predictions),
@@ -516,8 +516,8 @@ def _customize_outputs(
                 masks=masks,
                 polygons=[],
                 labels=labels,
-                saliency_maps=list(saliency_maps),
-                feature_vectors=list(feature_vectors),
+                saliency_map=list(saliency_map),
+                feature_vector=list(feature_vector),
             )

         return InstanceSegBatchPredEntity(
@@ -659,8 +659,8 @@ def _customize_outputs(
             masks=masks,
             polygons=[],
             labels=labels,
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return InstanceSegBatchPredEntity(
diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py
index d498bb75b65..c56c440a883 100644
--- a/src/otx/core/model/segmentation.py
+++ b/src/otx/core/model/segmentation.py
@@ -186,8 +186,8 @@ def _customize_outputs(
             imgs_info=inputs.imgs_info,
             scores=[],
             masks=masks,
-            saliency_maps=explain_results,
-            feature_vectors=[],
+            saliency_map=explain_results,
+            feature_vector=[],
         )

         return SegBatchPredEntity(
@@ -260,8 +260,8 @@ def _customize_outputs(
             imgs_info=inputs.imgs_info,
             scores=[],
             masks=[tv_tensors.Mask(mask.resultImage) for mask in outputs],
-            saliency_maps=predicted_s_maps,
-            feature_vectors=predicted_f_vectors,
+            saliency_map=predicted_s_maps,
+            feature_vector=predicted_f_vectors,
         )

         return SegBatchPredEntity(
diff --git a/src/otx/core/model/utils/mmpretrain.py b/src/otx/core/model/utils/mmpretrain.py
index deaed816e25..ff25374352b 100644
--- a/src/otx/core/model/utils/mmpretrain.py
+++ b/src/otx/core/model/utils/mmpretrain.py
@@ -5,14 +5,20 @@

 from __future__ import annotations

-from typing import TYPE_CHECKING
+import types
+from typing import TYPE_CHECKING, Callable, Generic

+import torch
 from mmpretrain.models.utils import ClsDataPreprocessor as _ClsDataPreprocessor
 from mmpretrain.registry import MODELS

+from otx.algo.hooks.recording_forward_hook import get_feature_vector
+from otx.core.data.entity.base import T_OTXBatchDataEntity, T_OTXBatchPredEntity
 from otx.core.utils.build import build_mm_model, get_classification_layers

 if TYPE_CHECKING:
+    from mmpretrain.models.classifiers.image import ImageClassifier
+    from mmpretrain.structures import DataSample
     from omegaconf import DictConfig
     from torch import device, nn

@@ -48,3 +54,138 @@ def create_model(config: DictConfig, load_from: str | None = None) -> tuple[nn.M
     """
     classification_layers = get_classification_layers(config, MODELS, "model.")
     return build_mm_model(config, MODELS, load_from), classification_layers
+
+
+class ExplainableMixInMMPretrainModel(Generic[T_OTXBatchPredEntity, T_OTXBatchDataEntity]):
+    """Mix-in class that adds common XAI support for MMPretrain models.
+
+    This mix-in cannot be used standalone. The correct usage looks like
+    ```python
+    class MobileNetV3ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
+    ```
+    """
+
+    explain_mode: bool
+    num_classes: int
+    model: ImageClassifier
+
+    @property
+    def has_gap(self) -> bool:
+        """Defines if GAP is used right after backbone.
+
+        Note:
+            Can be redefined at the model's level.
+        """
+        return True
+
+    @torch.no_grad()
+    def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
+        """Performs model's neck and head forward.
+
+        Note:
+            Can be redefined at the model's level.
+        """
+        if (neck := getattr(self.model, "neck", None)) is None:
+            raise ValueError
+        if (head := getattr(self.model, "head", None)) is None:
+            raise ValueError
+
+        output = neck(x)
+        return head([output])
+
+    @staticmethod
+    def _forward_explain_image_classifier(
+        self: ImageClassifier,
+        inputs: torch.Tensor,
+        data_samples: list[DataSample] | None = None,
+        mode: str = "tensor",
+    ) -> dict[str, torch.Tensor]:
+        """Forward func of the ImageClassifier instance, which is located in ExplainableMixInMMPretrainModel().model.
+
+        Note:
+            Can be redefined at the model's level.
+        """
+        x = self.backbone(inputs)
+        backbone_feat = x
+
+        feature_vector = self.feature_vector_fn(backbone_feat)
+        saliency_map = self.explain_fn(backbone_feat)
+
+        if self.with_neck:
+            x = self.neck(x)
+
+        if mode == "tensor":
+            logits = self.head(x) if self.with_head else x
+        elif mode == "predict":
+            logits = self.head.predict(x, data_samples)
+        else:
+            msg = f'Invalid mode "{mode}".'
+            raise RuntimeError(msg)
+
+        return {
+            "logits": logits,
+            "feature_vector": feature_vector,
+            "saliency_map": saliency_map,
+        }
+
+    def get_explain_fn(self) -> Callable:
+        """Returns explain function.
+
+        Note:
+            Can be redefined at the model's level.
+        """
+        from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
+
+        explainer = ReciproCAMHook(
+            self.head_forward_fn,
+            num_classes=self.num_classes,
+            optimize_gap=self.has_gap,
+        )
+        return explainer.func
+
+    def forward_explain(
+        self,
+        inputs: T_OTXBatchDataEntity,
+    ) -> T_OTXBatchPredEntity:
+        """Model forward function."""
+        forward_func: Callable[[T_OTXBatchDataEntity], T_OTXBatchPredEntity] | None = getattr(self, "forward", None)
+
+        if forward_func is None:
+            msg = (
+                "This instance has no forward function. "
+                "Did you attach this mixin into a class derived from OTXModel?"
+            )
+            raise RuntimeError(msg)
+
+        try:
+            self._reset_model_forward()
+            return forward_func(inputs)
+        finally:
+            self._restore_model_forward()
+
+    def _reset_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
+        if not self.explain_mode:
+            return
+
+        self.model.feature_vector_fn = get_feature_vector
+        self.model.explain_fn = self.get_explain_fn()
+        forward_with_explain = self._forward_explain_image_classifier
+
+        self.original_model_forward = self.model.forward
+
+        func_type = types.MethodType
+        self.model.forward = func_type(forward_with_explain, self.model)
+
+    def _restore_model_forward(self) -> None:
+        # TODO(vinnamkim): This will be revisited by the export refactoring
+        if not self.explain_mode:
+            return
+
+        if not self.original_model_forward:
+            msg = "Original model forward was not saved."
+            raise RuntimeError(msg)
+
+        func_type = types.MethodType
+        self.model.forward = func_type(self.original_model_forward, self.model)
+        self.original_model_forward = None
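The `_reset_model_forward`/`_restore_model_forward` pair relies on `types.MethodType` binding a replacement function to the instance, which shadows the class-level `forward`. A self-contained sketch of that swap-and-restore pattern on a toy module (not OTX code):

```python
import types
import torch

net = torch.nn.Linear(4, 2)
original_forward = net.forward


def forward_with_explain(self: torch.nn.Module, x: torch.Tensor) -> dict:
    # Bound in place of the original forward; returns extras alongside logits.
    return {"logits": original_forward(x), "feature_vector": x.mean(-1)}


net.forward = types.MethodType(forward_with_explain, net)  # swap in
out = net(torch.randn(1, 4))
print(sorted(out.keys()))  # ['feature_vector', 'logits']
net.forward = original_forward  # restore
```

Wrapping the swap in `try`/`finally` (as `forward_explain` does above) guarantees the original forward is restored even if the explain pass raises.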
diff --git a/src/otx/core/types/explain.py b/src/otx/core/types/explain.py
index 5ed8b94bea8..cb17cbca020 100644
--- a/src/otx/core/types/explain.py
+++ b/src/otx/core/types/explain.py
@@ -6,6 +6,11 @@
 from __future__ import annotations

 from enum import Enum
+from typing import Sequence
+
+import torch
+
+FeatureMapType = torch.Tensor | Sequence[torch.Tensor]


 class TargetExplainGroup(str, Enum):
diff --git a/src/otx/engine/engine.py b/src/otx/engine/engine.py
index c00a0fe2f4b..32b93a49497 100644
--- a/src/otx/engine/engine.py
+++ b/src/otx/engine/engine.py
@@ -67,9 +67,11 @@ def override_metric_callable(model: OTXModel, new_metric_callable: MetricCallabl
         return

     orig_metric_callable = model.metric_callable
-    model.metric_callable = new_metric_callable
-    yield model
-    model.metric_callable = orig_metric_callable
+    try:
+        model.metric_callable = new_metric_callable
+        yield model
+    finally:
+        model.metric_callable = orig_metric_callable


 class Engine:
@@ -445,15 +447,19 @@ def predict(
             loaded_checkpoint = torch.load(checkpoint)
             model.load_state_dict(loaded_checkpoint)

-        model.explain_mode = explain
-
         self._build_trainer(**kwargs)

-        predict_result = self.trainer.predict(
-            model=model,
-            dataloaders=datamodule,
-            return_predictions=return_predictions,
-        )
+        curr_explain_mode = model.explain_mode
+
+        try:
+            model.explain_mode = explain
+            predict_result = self.trainer.predict(
+                model=model,
+                dataloaders=datamodule,
+                return_predictions=return_predictions,
+            )
+        finally:
+            model.explain_mode = curr_explain_mode

         if explain:
             if explain_config is None:
@@ -461,7 +467,6 @@ def predict(

             predict_result = process_saliency_maps_in_pred_entity(predict_result, explain_config)

-        model.explain_mode = False
         return predict_result

     def export(
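With this change, `predict()` saves and restores `explain_mode` under `try`/`finally`, the same discipline now applied to `override_metric_callable`, so a failed prediction can no longer leave the model stuck in explain mode, and the stray `model.explain_mode = False` at the end becomes unnecessary. The same idea could be factored into a small context manager; a minimal sketch (`explain_mode_guard` is a hypothetical helper, not part of this change):

```python
from contextlib import contextmanager

@contextmanager
def explain_mode_guard(model, explain: bool):
    """Temporarily toggle `model.explain_mode`, always restoring the previous value."""
    previous = model.explain_mode
    try:
        model.explain_mode = explain
        yield model
    finally:
        model.explain_mode = previous
```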
diff --git a/tests/integration/api/test_engine_api.py b/tests/integration/api/test_engine_api.py
index 5ecf36e0d41..446cb500d3d 100644
--- a/tests/integration/api/test_engine_api.py
+++ b/tests/integration/api/test_engine_api.py
@@ -105,7 +105,7 @@ def test_engine_from_config(

     # Predict Torch model with explain
     predictions = engine.predict(explain=True)
-    assert len(predictions[0].saliency_maps) > 0
+    assert len(predictions[0].saliency_map) > 0

     # Export IR model with explain
     exported_model_with_explain = engine.export(explain=True)
@@ -114,13 +114,13 @@
     # Infer IR Model with explain: predict
     predictions = engine.predict(explain=True, checkpoint=exported_model_with_explain, accelerator="cpu")
     assert len(predictions) > 0
-    sal_maps_from_prediction = predictions[0].saliency_maps
+    sal_maps_from_prediction = predictions[0].saliency_map
     assert len(sal_maps_from_prediction) > 0

     # Infer IR Model with explain: explain
     explain_results = engine.explain(checkpoint=exported_model_with_explain, accelerator="cpu")
-    assert len(explain_results[0].saliency_maps) > 0
-    sal_maps_from_explain = explain_results[0].saliency_maps
+    assert len(explain_results[0].saliency_map) > 0
+    sal_maps_from_explain = explain_results[0].saliency_map
     assert (sal_maps_from_prediction[0][0] == sal_maps_from_explain[0][0]).all()
diff --git a/tests/integration/api/test_xai.py b/tests/integration/api/test_xai.py
index 63e3bc4890c..26f7ace117d 100644
--- a/tests/integration/api/test_xai.py
+++ b/tests/integration/api/test_xai.py
@@ -55,6 +55,7 @@ def test_forward_explain(

     predict_result = engine.predict()
     assert isinstance(predict_result[0], OTXBatchPredEntity)
+    assert not predict_result[0].has_xai_outputs

     predict_result_explain = engine.predict(explain=True)
     assert isinstance(predict_result_explain[0], OTXBatchPredEntity)
@@ -109,8 +110,8 @@ def test_predict_with_explain(
     predict_result_explain_torch = engine.predict(explain=True)
     assert isinstance(predict_result_explain_torch[0], OTXBatchPredEntity)
     assert predict_result_explain_torch[0].has_xai_outputs
-    assert predict_result_explain_torch[0].saliency_maps is not None
-    assert isinstance(predict_result_explain_torch[0].saliency_maps[0], dict)
+    assert predict_result_explain_torch[0].saliency_map is not None
+    assert isinstance(predict_result_explain_torch[0].saliency_map[0], dict)

     # Export with explain
     ckpt_path = tmp_path / "checkpoint.ckpt"
@@ -138,10 +139,10 @@ def test_predict_with_explain(
     predict_result_explain_ov = engine.predict(checkpoint=exported_model_path, explain=True)
     assert isinstance(predict_result_explain_ov[0], OTXBatchPredEntity)
     assert predict_result_explain_ov[0].has_xai_outputs
-    assert predict_result_explain_ov[0].saliency_maps is not None
-    assert isinstance(predict_result_explain_ov[0].saliency_maps[0], dict)
-    assert predict_result_explain_ov[0].feature_vectors is not None
-    assert isinstance(predict_result_explain_ov[0].feature_vectors[0], np.ndarray)
+    assert predict_result_explain_ov[0].saliency_map is not None
+    assert isinstance(predict_result_explain_ov[0].saliency_map[0], dict)
+    assert predict_result_explain_ov[0].feature_vector is not None
+    assert isinstance(predict_result_explain_ov[0].feature_vector[0], np.ndarray)

     if task == "instance_segmentation" or "atss_r50_fpn" in recipe:
         # For instance segmentation and atss_r50_fpn batch_size for Torch task 1, for OV 2.
@@ -151,8 +152,8 @@ def test_predict_with_explain(
         # TODO(gzalessk): remove this if statement when the issue is resolved
         return

-    maps_torch = predict_result_explain_torch[0].saliency_maps
-    maps_ov = predict_result_explain_ov[0].saliency_maps
+    maps_torch = predict_result_explain_torch[0].saliency_map
+    maps_ov = predict_result_explain_ov[0].saliency_map

     assert len(maps_torch) == len(maps_ov)
diff --git a/tests/integration/cli/test_cli.py b/tests/integration/cli/test_cli.py
index ad48e464a4f..a5d7a63f9fa 100644
--- a/tests/integration/cli/test_cli.py
+++ b/tests/integration/cli/test_cli.py
@@ -316,9 +316,9 @@ def test_otx_explain_e2e(
         (p for p in outputs_dir.iterdir() if p.is_dir() and p.name != ".latest"),
         key=lambda p: p.stat().st_mtime,
     )
-    assert (latest_dir / "saliency_maps").exists()
-    saliency_maps = sorted((latest_dir / "saliency_maps").glob(pattern="*.png"))
-    sal_map = cv2.imread(str(saliency_maps[0]))
+    assert (latest_dir / "saliency_map").exists()
+    saliency_map = sorted((latest_dir / "saliency_map").glob(pattern="*.png"))
+    sal_map = cv2.imread(str(saliency_map[0]))
     assert sal_map.shape[0] > 0
     assert sal_map.shape[1] > 0

@@ -354,7 +354,7 @@ def test_otx_explain_e2e(
     }
     test_case_name = task + "_" + model_name
     if test_case_name in reference_sal_vals:
-        actual_sal_vals = cv2.imread(str(latest_dir / "saliency_maps" / reference_sal_vals[test_case_name][1]))
+        actual_sal_vals = cv2.imread(str(latest_dir / "saliency_map" / reference_sal_vals[test_case_name][1]))
         if test_case_name == "instance_segmentation_maskrcnn_efficientnetb2b":
             # Take corner values due to map sparsity of InstSeg
             actual_sal_vals = (actual_sal_vals[-10:, -1, -1]).astype(np.uint16)
diff --git a/tests/unit/algo/classification/conftest.py b/tests/unit/algo/classification/conftest.py
index 61f34510b5c..628dac11428 100644
--- a/tests/unit/algo/classification/conftest.py
+++ b/tests/unit/algo/classification/conftest.py
@@ -3,12 +3,15 @@
 #
 from __future__ import annotations

+from dataclasses import asdict
+
 import pytest
 import torch
 from mmpretrain.structures import DataSample
 from omegaconf import DictConfig
 from otx.core.data.dataset.classification import MulticlassClsBatchDataEntity
 from otx.core.data.entity.base import ImageInfo
+from otx.core.data.entity.classification import HlabelClsBatchDataEntity, MultilabelClsBatchDataEntity
 from otx.core.types.label import HLabelInfo
 from torchvision import tv_tensors

@@ -153,6 +156,27 @@ def fxt_multiclass_cls_batch_data_entity() -> MulticlassClsBatchDataEntity:
     )


+@pytest.fixture()
+def fxt_multilabel_cls_batch_data_entity(
+    fxt_multiclass_cls_batch_data_entity,
+    fxt_hlabel_data,
+) -> MultilabelClsBatchDataEntity:
+    return MultilabelClsBatchDataEntity(
+        batch_size=2,
+        images=fxt_multiclass_cls_batch_data_entity.images,
+        imgs_info=fxt_multiclass_cls_batch_data_entity.imgs_info,
+        labels=[
+            torch.nn.functional.one_hot(label, num_classes=fxt_hlabel_data.num_classes).flatten()
+            for label in fxt_multiclass_cls_batch_data_entity.labels
+        ],
+    )
+
+
+@pytest.fixture()
+def fxt_hlabel_cls_batch_data_entity(fxt_multilabel_cls_batch_data_entity) -> HlabelClsBatchDataEntity:
+    return HlabelClsBatchDataEntity(**asdict(fxt_multilabel_cls_batch_data_entity))
+
+
 @pytest.fixture()
 def fxt_config_mock() -> DictConfig:
     pseudo_model_config = {
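The new `fxt_multilabel_cls_batch_data_entity` fixture derives multilabel targets by one-hot encoding each multiclass label. For instance, with three classes (illustrative values):

```python
import torch

label = torch.tensor([2])
target = torch.nn.functional.one_hot(label, num_classes=3).flatten()
# tensor([0, 0, 1]) -- the per-image multilabel target vector
```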
diff --git a/tests/unit/algo/classification/test_deit_tiny.py b/tests/unit/algo/classification/test_deit_tiny.py
index ae1b15141e0..4c0095a62d0 100644
--- a/tests/unit/algo/classification/test_deit_tiny.py
+++ b/tests/unit/algo/classification/test_deit_tiny.py
@@ -2,39 +2,72 @@
 # SPDX-License-Identifier: Apache-2.0
 #
+from pathlib import Path
+
 import pytest
-import torch
 from otx.algo.classification.deit_tiny import (
     DeitTinyForHLabelCls,
     DeitTinyForMulticlassCls,
     DeitTinyForMultilabelCls,
 )
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
+from otx.core.data.entity.base import OTXBatchLossEntity
+from otx.core.types.export import OTXExportFormatType
+from otx.core.types.precision import OTXPrecisionType


 class TestDeitTiny:
-    @pytest.mark.parametrize(
-        "model_cls",
-        [DeitTinyForMulticlassCls, DeitTinyForMultilabelCls, DeitTinyForHLabelCls],
+    @pytest.fixture(
+        params=[
+            (DeitTinyForMulticlassCls, "fxt_multiclass_cls_batch_data_entity"),
+            (DeitTinyForMultilabelCls, "fxt_multilabel_cls_batch_data_entity"),
+            (DeitTinyForHLabelCls, "fxt_hlabel_cls_batch_data_entity"),
+        ],
+        ids=["multiclass", "multilabel", "hlabel"],
     )
-    def test_deit_tiny(self, model_cls, mocker, fxt_hlabel_data):
+    def fxt_model_and_input(self, request, fxt_hlabel_data):
+        model_cls, input_fxt_name = request.param
+        fxt_input = request.getfixturevalue(input_fxt_name)
         num_classes = fxt_hlabel_data.num_classes
         if model_cls == DeitTinyForHLabelCls:
             model = model_cls(hlabel_info=fxt_hlabel_data)
         else:
             model = model_cls(num_classes=num_classes)

-        model.model.explain_fn = model.get_explain_fn()
-        assert model._optimization_config["model_type"] == "transformer"
+        return model, fxt_input
+
+    @pytest.mark.parametrize("explain_mode", [True, False])
+    def test_deit_tiny(self, fxt_model_and_input, explain_mode, mocker):
+        fxt_model, fxt_input = fxt_model_and_input
+
+        fxt_model.train()
+        assert isinstance(fxt_model.forward(fxt_input), OTXBatchLossEntity)

-        assert model.head_forward_fn(torch.randn([1, 24, 192])).shape == torch.Size([1, num_classes])
+        fxt_model.eval()
+        assert not isinstance(fxt_model.forward(fxt_input), OTXBatchLossEntity)

-        out = model._forward_explain_image_classifier(model.model, torch.randn(1, 3, 24, 24))
-        assert out["logits"].shape == torch.Size([1, num_classes])
-        assert out["feature_vector"].shape == torch.Size([1, 192])
-        assert out["saliency_map"].shape == torch.Size([1, num_classes, 2, 2])
+        fxt_model.explain_mode = explain_mode
+        preds = fxt_model.predict_step(fxt_input, batch_idx=0)
+        assert len(preds.labels) == fxt_input.batch_size
+        assert len(preds.scores) == fxt_input.batch_size
+        assert preds.has_xai_outputs == explain_mode

         mock_load_ckpt = mocker.patch.object(OTXv1Helper, "load_cls_effnet_b0_ckpt")
-        model.load_from_otx_v1_ckpt({})
+        fxt_model.load_from_otx_v1_ckpt({})
         mock_load_ckpt.assert_called_once_with({}, "multiclass", "model.model.")
+
+    @pytest.mark.parametrize("explain_mode", [True, False])
+    def test_export(self, fxt_model_and_input, explain_mode, tmpdir):
+        base_name = "exported_model"
+
+        fxt_model, _ = fxt_model_and_input
+        fxt_model.eval()
+        fxt_model.explain_mode = explain_mode
+
+        fxt_model.export(
+            output_dir=Path(tmpdir),
+            base_name=base_name,
+            export_format=OTXExportFormatType.OPENVINO,
+            precision=OTXPrecisionType.FP16,
+        )
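The `fxt_model_and_input` fixture pairs each model class with the name of a batch-entity fixture and resolves it lazily via `request.getfixturevalue`, since `params` cannot reference fixtures directly. A reduced sketch of the pattern (fixture names here are hypothetical):

```python
import pytest

@pytest.fixture(params=[("multiclass", "fxt_multiclass"), ("multilabel", "fxt_multilabel")])
def fxt_case(request):
    kind, fixture_name = request.param
    # getfixturevalue resolves a fixture by name at collection-independent runtime
    return kind, request.getfixturevalue(fixture_name)
```

Incidentally, `test_export` could take pytest's `tmp_path` fixture, which already yields a `pathlib.Path`, instead of wrapping the legacy `tmpdir` in `Path(...)`.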
diff --git a/tests/unit/algo/classification/test_torchvision_model.py b/tests/unit/algo/classification/test_torchvision_model.py
index fe74195da06..295080fd40c 100644
--- a/tests/unit/algo/classification/test_torchvision_model.py
+++ b/tests/unit/algo/classification/test_torchvision_model.py
@@ -1,46 +1,33 @@
 import pytest
 import torch
 from otx.algo.classification.torchvision_model import OTXTVModel, TVModelWithLossComputation
-from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
-from otx.core.data.entity.classification import (
-    MulticlassClsBatchDataEntity,
-    MulticlassClsBatchPredEntity,
-)
+from otx.core.data.entity.base import OTXBatchLossEntity
+from otx.core.data.entity.classification import MulticlassClsBatchPredEntity


 @pytest.fixture()
 def fxt_tv_model():
-    return OTXTVModel(backbone="resnet50", num_classes=10)
-
-
-@pytest.fixture()
-def fxt_inputs():
-    return MulticlassClsBatchDataEntity(
-        batch_size=16,
-        images=torch.randn(16, 3, 224, 224),
-        imgs_info=[ImageInfo(img_idx=i, img_shape=(224, 224), ori_shape=(224, 224)) for i in range(16)],
-        labels=[torch.randint(0, 10, (16,))],
-    )
+    return OTXTVModel(backbone="mobilenet_v3_small", num_classes=10)


 class TestOTXTVModel:
     def test_create_model(self, fxt_tv_model):
         assert isinstance(fxt_tv_model.model, TVModelWithLossComputation)

-    def test_customize_inputs(self, fxt_tv_model, fxt_inputs):
-        outputs = fxt_tv_model._customize_inputs(fxt_inputs)
+    def test_customize_inputs(self, fxt_tv_model, fxt_multiclass_cls_batch_data_entity):
+        outputs = fxt_tv_model._customize_inputs(fxt_multiclass_cls_batch_data_entity)
         assert "images" in outputs
         assert "labels" in outputs
         assert "mode" in outputs

-    def test_customize_outputs(self, fxt_tv_model, fxt_inputs):
-        outputs = torch.randn(16, 10)
+    def test_customize_outputs(self, fxt_tv_model, fxt_multiclass_cls_batch_data_entity):
+        outputs = torch.randn(2, 10)
         fxt_tv_model.training = True
-        preds = fxt_tv_model._customize_outputs(outputs, fxt_inputs)
+        preds = fxt_tv_model._customize_outputs(outputs, fxt_multiclass_cls_batch_data_entity)
         assert isinstance(preds, OTXBatchLossEntity)

         fxt_tv_model.training = False
-        preds = fxt_tv_model._customize_outputs(outputs, fxt_inputs)
+        preds = fxt_tv_model._customize_outputs(outputs, fxt_multiclass_cls_batch_data_entity)
         assert isinstance(preds, MulticlassClsBatchPredEntity)

     def test_export_parameters(self, fxt_tv_model):
@@ -55,19 +42,14 @@ def test_export_parameters(self, fxt_tv_model):
         assert "mean" in params
         assert "std" in params

-    def test_forward_explain_image_classifier(self, fxt_tv_model):
-        images = torch.randn(16, 3, 224, 224)
-        fxt_tv_model._explain_mode = True
-        fxt_tv_model._reset_model_forward()
-        outputs = fxt_tv_model._forward_explain_image_classifier(fxt_tv_model.model, images)
-        assert "logits" in outputs
-        assert "feature_vector" in outputs
-        assert "saliency_map" in outputs
+    @pytest.mark.parametrize("explain_mode", [True, False])
+    def test_predict_step(self, fxt_tv_model: OTXTVModel, fxt_multiclass_cls_batch_data_entity, explain_mode):
+        fxt_tv_model.eval()
+        fxt_tv_model.explain_mode = explain_mode
+        outputs = fxt_tv_model.predict_step(batch=fxt_multiclass_cls_batch_data_entity, batch_idx=0)

-    def test_head_forward_fn(self, fxt_tv_model):
-        x = torch.randn(16, 2048)
-        output = fxt_tv_model.head_forward_fn(x)
-        assert output.shape == (16, 10)
+        assert isinstance(outputs, MulticlassClsBatchPredEntity)
+        assert outputs.has_xai_outputs == explain_mode

     def test_freeze_backbone(self):
         freezed_model = OTXTVModel(backbone="resnet50", num_classes=10, freeze_backbone=True)
diff --git a/tests/unit/algo/hooks/test_saliency_map_dumping.py b/tests/unit/algo/hooks/test_saliency_map_dumping.py
index 3f790981d02..643c04b4c43 100644
--- a/tests/unit/algo/hooks/test_saliency_map_dumping.py
+++ b/tests/unit/algo/hooks/test_saliency_map_dumping.py
@@ -36,8 +36,8 @@ def test_sal_map_dump(
             imgs_info=IMGS_INFO,
             scores=None,
             labels=None,
-            saliency_maps=SALIENCY_MAPS,
-            feature_vectors=None,
+            saliency_map=SALIENCY_MAPS,
+            feature_vector=None,
         ),
     ]

@@ -48,7 +48,7 @@ def test_sal_map_dump(
         output_dir=tmp_path,
    )

-    saliency_maps_paths = sorted((tmp_path / "saliency_maps").glob(pattern="*.png"))
+    saliency_maps_paths = sorted((tmp_path / "saliency_map").glob(pattern="*.png"))

     assert len(saliency_maps_paths) == NUM_CLASSES * BATCH_SIZE
diff --git a/tests/unit/algo/hooks/test_saliency_map_processing.py b/tests/unit/algo/hooks/test_saliency_map_processing.py
index fdc8f2739b5..2f3320ebcbe 100644
--- a/tests/unit/algo/hooks/test_saliency_map_processing.py
+++ b/tests/unit/algo/hooks/test_saliency_map_processing.py
@@ -107,8 +107,8 @@ def _get_pred_result_multiclass(pred_labels) -> MulticlassClsBatchPredEntity:
         imgs_info=IMGS_INFO,
         scores=None,
         labels=pred_labels,
-        saliency_maps=SALIENCY_MAPS,
-        feature_vectors=None,
+        saliency_map=SALIENCY_MAPS,
+        feature_vector=None,
     )


@@ -119,8 +119,8 @@ def _get_pred_result_multilabel(pred_labels) -> MultilabelClsBatchPredEntity:
         imgs_info=IMGS_INFO,
         scores=None,
         labels=pred_labels,
-        saliency_maps=SALIENCY_MAPS,
-        feature_vectors=None,
+        saliency_map=SALIENCY_MAPS,
+        feature_vector=None,
     )


@@ -137,9 +137,9 @@ def test_process_saliency_maps_in_pred_entity_multiclass() -> None:
     )

     for i in range(len(predict_result)):
-        assert isinstance(predict_result[i].saliency_maps, list)
-        assert isinstance(predict_result[i].saliency_maps[0], dict)
-        processed_saliency_maps = predict_result[i].saliency_maps
+        assert isinstance(predict_result[i].saliency_map, list)
+        assert isinstance(predict_result[i].saliency_map[0], dict)
+        processed_saliency_maps = predict_result[i].saliency_map
         assert all(len(s_map_dict) == 1 for s_map_dict in processed_saliency_maps)


@@ -156,7 +156,7 @@ def test_process_saliency_maps_in_pred_entity_multilabel() -> None:
     )

     for i in range(len(predict_result)):
-        assert isinstance(predict_result[i].saliency_maps, list)
-        assert isinstance(predict_result[i].saliency_maps[0], dict)
-        processed_saliency_maps = predict_result[i].saliency_maps
+        assert isinstance(predict_result[i].saliency_map, list)
+        assert isinstance(predict_result[i].saliency_map[0], dict)
+        processed_saliency_maps = predict_result[i].saliency_map
         assert all(len(s_map_dict) == len(PRED_LABELS[i]) for (i, s_map_dict) in enumerate(processed_saliency_maps))
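These assertions pin down the processed format: after `process_saliency_maps_in_pred_entity`, the renamed `saliency_map` field holds one dict per image mapping a label index to its 2D map, a single entry for multiclass (the predicted class only) and one entry per predicted label for multilabel. Schematically (keys and shapes are illustrative):

```python
# predict_result[i].saliency_map after processing -- one dict per image:
# multiclass:  [{7: map_2d}, ...]                  # len(dict) == 1
# multilabel:  [{1: map_a, 4: map_b}, ...]         # len(dict) == len(PRED_LABELS[i])
```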
diff --git a/tests/unit/algo/hooks/test_xai_hooks.py b/tests/unit/algo/hooks/test_xai_hooks.py
index 65869496356..55d9c63f829 100644
--- a/tests/unit/algo/hooks/test_xai_hooks.py
+++ b/tests/unit/algo/hooks/test_xai_hooks.py
@@ -24,8 +24,8 @@ def test_activationmap() -> None:

     feature_map = torch.zeros((1, 10, 5, 5))

-    saliency_maps = hook.func(feature_map)
-    assert saliency_maps.size() == torch.Size([1, 5, 5])
+    saliency_map = hook.func(feature_map)
+    assert saliency_map.size() == torch.Size([1, 5, 5])

     hook.recording_forward(None, None, feature_map)
     assert len(hook.records) == 1
@@ -52,8 +52,8 @@ def cls_head_forward_fn(_) -> None:

     feature_map = torch.zeros((1, 10, 5, 5))

-    saliency_maps = hook.func(feature_map)
-    assert saliency_maps.size() == torch.Size([1, 2, 5, 5])
+    saliency_map = hook.func(feature_map)
+    assert saliency_map.size() == torch.Size([1, 2, 5, 5])

     hook.recording_forward(None, None, feature_map)
     assert len(hook.records) == 1
@@ -78,8 +78,8 @@ def cls_head_forward_fn(_) -> None:

     feature_map = torch.zeros((1, 197, 192))

-    saliency_maps = hook.func(feature_map)
-    assert saliency_maps.size() == torch.Size([1, 2, 14, 14])
+    saliency_map = hook.func(feature_map)
+    assert saliency_map.size() == torch.Size([1, 2, 14, 14])

     hook.recording_forward(None, None, feature_map)
     assert len(hook.records) == 1
@@ -102,8 +102,8 @@ def test_detclassprob() -> None:

     backbone_out = torch.zeros((1, 5, 2, 2, 2))

-    saliency_maps = hook.func(backbone_out)
-    assert saliency_maps.size() == torch.Size([5, 2, 2, 2])
+    saliency_map = hook.func(backbone_out)
+    assert saliency_map.size() == torch.Size([5, 2, 2, 2])


 def test_maskrcnn() -> None:
@@ -137,6 +137,6 @@ def test_maskrcnn() -> None:
     )

     # 2 images
-    saliency_maps = hook.func([pred, pred])
-    assert len(saliency_maps) == 2
-    assert saliency_maps[0].shape == (2, 10, 10)
+    saliency_map = hook.func([pred, pred])
+    assert len(saliency_map) == 2
+    assert saliency_map[0].shape == (2, 10, 10)
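A note on the ViT hunk above: a `(1, 197, 192)` feature map corresponds to one class token plus 14 x 14 = 196 patch tokens at embedding dimension 192, which is why ReciproCAM yields a `(1, 2, 14, 14)` saliency map for the two classes.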