Skip to content

Commit

Permalink
Refactor XAI model part - classification tasks (#3242)
Browse files Browse the repository at this point in the history
* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

* Fix ruff

Signed-off-by: Kim, Vinnam <[email protected]>

* Update src/otx/algo/utils/xai_utils.py

Co-authored-by: Galina Zalesskaya <[email protected]>

* Update src/otx/core/model/utils/mmpretrain.py

* Rename

 - ForwardExplainMixInForMMPretrain => ExplainableMixInMMPretrainModel

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
Co-authored-by: Galina Zalesskaya <[email protected]>
  • Loading branch information
vinnamkim and GalyaZalesskaya authored Apr 4, 2024
1 parent cd6702a commit acf03b9
Show file tree
Hide file tree
Showing 26 changed files with 594 additions and 439 deletions.
12 changes: 6 additions & 6 deletions src/otx/algo/classification/deit_tiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@
from otx.core.metrics.accuracy import HLabelClsMetricCallble, MultiClassClsMetricCallable, MultiLabelClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import (
ExplainableOTXClsModel,
MMPretrainHlabelClsModel,
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from mmpretrain.models import ImageClassifier
from mmpretrain.models.classifiers import ImageClassifier
from mmpretrain.structures import DataSample

from otx.core.metrics import MetricCallable


class ExplainableDeit(ExplainableOTXClsModel):
class ForwardExplainMixInForDeit(ExplainableMixInMMPretrainModel):
"""Deit model which can attach an XAI hook."""

@torch.no_grad()
Expand Down Expand Up @@ -145,7 +145,7 @@ def _optimization_config(self) -> dict[str, Any]:
return {"model_type": "transformer"}


class DeitTinyForHLabelCls(ExplainableDeit, MMPretrainHlabelClsModel):
class DeitTinyForHLabelCls(ForwardExplainMixInForDeit, MMPretrainHlabelClsModel):
"""DeitTiny Model for hierarchical label classification task."""

def __init__(
Expand All @@ -172,7 +172,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class DeitTinyForMulticlassCls(ExplainableDeit, MMPretrainMulticlassClsModel):
class DeitTinyForMulticlassCls(ForwardExplainMixInForDeit, MMPretrainMulticlassClsModel):
"""DeitTiny Model for multi-class classification task."""

def __init__(
Expand All @@ -198,7 +198,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class DeitTinyForMultilabelCls(ExplainableDeit, MMPretrainMultilabelClsModel):
class DeitTinyForMultilabelCls(ForwardExplainMixInForDeit, MMPretrainMultilabelClsModel):
"""DeitTiny Model for multi-label classification task."""

def __init__(
Expand Down
7 changes: 4 additions & 3 deletions src/otx/algo/classification/efficientnet_b0.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

Expand All @@ -24,7 +25,7 @@
from otx.core.metrics import MetricCallable


class EfficientNetB0ForHLabelCls(MMPretrainHlabelClsModel):
class EfficientNetB0ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
"""EfficientNetB0 Model for hierarchical label classification task."""

def __init__(
Expand All @@ -51,7 +52,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "hlabel", add_prefix)


class EfficientNetB0ForMulticlassCls(MMPretrainMulticlassClsModel):
class EfficientNetB0ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
"""EfficientNetB0 Model for multi-class classification task."""

def __init__(
Expand Down Expand Up @@ -79,7 +80,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_b0_ckpt(state_dict, "multiclass", add_prefix)


class EfficientNetB0ForMultilabelCls(MMPretrainMultilabelClsModel):
class EfficientNetB0ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
"""EfficientNetB0 Model for multi-label classification task."""

def __init__(
Expand Down
7 changes: 4 additions & 3 deletions src/otx/algo/classification/efficientnet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

Expand All @@ -24,7 +25,7 @@
from otx.core.metrics import MetricCallable


class EfficientNetV2ForHLabelCls(MMPretrainHlabelClsModel):
class EfficientNetV2ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
"""EfficientNetV2 Model for hierarchical label classification task."""

def __init__(
Expand All @@ -51,7 +52,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "hlabel", add_prefix)


class EfficientNetV2ForMulticlassCls(MMPretrainMulticlassClsModel):
class EfficientNetV2ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
"""EfficientNetV2 Model for multi-class classification task."""

def __init__(
Expand Down Expand Up @@ -79,7 +80,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_effnet_v2_ckpt(state_dict, "multiclass", add_prefix)


class EfficientNetV2ForMultilabelCls(MMPretrainMultilabelClsModel):
class EfficientNetV2ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
"""EfficientNetV2 Model for multi-label classification task."""

def __init__(
Expand Down
7 changes: 4 additions & 3 deletions src/otx/algo/classification/mobilenet_v3_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
MMPretrainMulticlassClsModel,
MMPretrainMultilabelClsModel,
)
from otx.core.model.utils.mmpretrain import ExplainableMixInMMPretrainModel
from otx.core.schedulers import LRSchedulerListCallable
from otx.core.types.label import HLabelInfo

Expand All @@ -24,7 +25,7 @@
from otx.core.metrics import MetricCallable


class MobileNetV3ForHLabelCls(MMPretrainHlabelClsModel):
class MobileNetV3ForHLabelCls(ExplainableMixInMMPretrainModel, MMPretrainHlabelClsModel):
"""MobileNetV3 Model for hierarchical label classification task."""

def __init__(
Expand Down Expand Up @@ -58,7 +59,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "hlabel", add_prefix)


class MobileNetV3ForMulticlassCls(MMPretrainMulticlassClsModel):
class MobileNetV3ForMulticlassCls(ExplainableMixInMMPretrainModel, MMPretrainMulticlassClsModel):
"""MobileNetV3 Model for multi-class classification task."""

def __init__(
Expand Down Expand Up @@ -93,7 +94,7 @@ def load_from_otx_v1_ckpt(self, state_dict: dict, add_prefix: str = "model.model
return OTXv1Helper.load_cls_mobilenet_v3_ckpt(state_dict, "multiclass", add_prefix)


class MobileNetV3ForMultilabelCls(MMPretrainMultilabelClsModel):
class MobileNetV3ForMultilabelCls(ExplainableMixInMMPretrainModel, MMPretrainMultilabelClsModel):
"""MobileNetV3 Model for multi-label classification task."""

def __init__(
Expand Down
134 changes: 75 additions & 59 deletions src/otx/algo/classification/torchvision_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@

import torch
from torch import nn
from torchvision import tv_tensors
from torchvision.models import get_model, get_model_weights

from otx.algo.hooks.recording_forward_hook import ReciproCAMHook
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.classification import MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable
from otx.core.model.classification import OTXMulticlassClsModel
Expand Down Expand Up @@ -137,6 +139,12 @@ def __init__(
self.softmax = nn.Softmax(dim=-1)
self.loss = loss

self.explainer = ReciproCAMHook(
self._head_forward_fn,
num_classes=num_classes,
optimize_gap=True,
)

def forward(
self,
images: torch.Tensor,
Expand All @@ -161,8 +169,39 @@ def forward(
return logits
if mode == "loss":
return self.loss(logits, labels)
if mode == "explain":
return self._forward_explain(images)

return self.softmax(logits)

def _forward_explain(self, images: torch.Tensor) -> dict[str, torch.Tensor | list[torch.Tensor]]:
# Explain-mode forward pass: runs the backbone once, derives a saliency map
# from the backbone output via the explainer hook, then classifies.
# Returns a dict with "logits", "preds", "scores", "saliency_map" and
# "feature_vector".
x = self.backbone(images)
backbone_feat = x

# The saliency map is computed from the raw backbone output, before the
# optional flattening below.
saliency_map = self.explainer.func(backbone_feat)

# 4-D backbone outputs are flattened to (batch, -1) unless use_layer_norm_2d
# is set (presumably the head then consumes the 4-D map itself - TODO confirm).
if len(x.shape) == 4 and not self.use_layer_norm_2d:
x = x.view(x.size(0), -1)

# The tensor fed to the head is also what gets reported as the feature vector.
feature_vector = x

logits = self.head(x)

return {
"logits": logits,
# Class index per sample; keepdim=False drops the reduced class dimension.
"preds": logits.argmax(-1, keepdim=False),
"scores": self.softmax(logits),
"saliency_map": saliency_map,
"feature_vector": feature_vector,
}

@torch.no_grad()
def _head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
"""Flatten backbone features when needed and run the classification head.

This is the callable handed to ReciproCAMHook in ``__init__``; it runs with
gradients disabled.
"""
# Same flattening rule as the explain forward: 4-D inputs become (batch, -1)
# unless use_layer_norm_2d is set.
if len(x.shape) == 4 and not self.use_layer_norm_2d:
x = x.view(x.size(0), -1)
return self.head(x)


class OTXTVModel(OTXMulticlassClsModel):
"""OTXTVModel represents a TorchVision model for multiclass classification.
Expand All @@ -174,6 +213,8 @@ class OTXTVModel(OTXMulticlassClsModel):
freeze_backbone (bool, optional): Whether to freeze the backbone model. Defaults to False.
"""

model: TVModelWithLossComputation

def __init__(
self,
backbone: TVModelType,
Expand Down Expand Up @@ -207,14 +248,17 @@ def _create_model(self) -> nn.Module:
)

def _customize_inputs(self, inputs: MulticlassClsBatchDataEntity) -> dict[str, Any]:
if isinstance(inputs.images, list):
images = tv_tensors.wrap(torch.stack(inputs.images, dim=0), like=inputs.images[0])
if self.training:
mode = "loss"
elif self.explain_mode:
mode = "explain"
else:
images = inputs.images
mode = "predict"

return {
"images": images,
"images": inputs.stacked_images,
"labels": torch.cat(inputs.labels, dim=0),
"mode": "loss" if self.training else "predict",
"mode": mode,
}

def _customize_outputs(
Expand All @@ -230,23 +274,6 @@ def _customize_outputs(
scores = torch.unbind(logits, 0)
preds = logits.argmax(-1, keepdim=True).unbind(0)

if self.explain_mode:
if not isinstance(outputs, dict) or "saliency_map" not in outputs:
msg = "No saliency maps in the model output."
raise ValueError(msg)

saliency_maps = outputs["saliency_map"].detach().cpu().numpy()

return MulticlassClsBatchPredEntity(
batch_size=len(preds),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=scores,
labels=preds,
saliency_maps=list(saliency_maps),
feature_vectors=[],
)

return MulticlassClsBatchPredEntity(
batch_size=inputs.batch_size,
images=inputs.images,
Expand All @@ -255,6 +282,11 @@ def _customize_outputs(
labels=preds,
)

@property
def _exporter(self) -> OTXModelExporter:
"""Creates OTXModelExporter object that can export the model."""
# Builds a native exporter, configured entirely from this model's
# _export_parameters mapping (unpacked as keyword arguments).
return OTXNativeModelExporter(**self._export_parameters)

@property
def _export_parameters(self) -> dict[str, Any]:
"""Defines parameters required to export a particular model implementation."""
Expand All @@ -270,44 +302,28 @@ def _export_parameters(self) -> dict[str, Any]:

parameters = super()._export_parameters
parameters.update(export_params)
return parameters

@staticmethod
def _forward_explain_image_classifier(
self: TVModelWithLossComputation,
images: torch.Tensor,
labels: torch.Tensor | None = None, # noqa: ARG004
mode: str = "tensor",
) -> dict:
"""Forward func of the TVModelWithLossComputation instance."""
x = self.backbone(images)
backbone_feat = x

saliency_map = self.explain_fn(backbone_feat)

if len(x.shape) == 4 and not self.use_layer_norm_2d:
x = x.view(x.size(0), -1)

feature_vector = x
if len(feature_vector.shape) == 1:
feature_vector = feature_vector.unsqueeze(0)
return parameters

logits = self.head(x)
if mode == "predict":
logits = self.softmax(logits)
def forward_explain(self, inputs: MulticlassClsBatchDataEntity) -> MulticlassClsBatchPredEntity:
"""Model forward explain function."""
outputs = self.model(images=inputs.stacked_images, mode="explain")

return {
"logits": logits,
"feature_vector": feature_vector,
"saliency_map": saliency_map,
}
return MulticlassClsBatchPredEntity(
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
labels=outputs["preds"],
scores=outputs["scores"],
saliency_map=outputs["saliency_map"],
feature_vector=outputs["feature_vector"],
)

@torch.no_grad()
def head_forward_fn(self, x: torch.Tensor) -> torch.Tensor:
"""Performs model's neck and head forward. Can be redefined at the model's level."""
if (head := getattr(self.model, "head", None)) is None:
raise ValueError
def _reset_model_forward(self) -> None:
# Replaces the wrapped model's ``forward`` with its ``_forward_explain`` so
# that calling the model yields the explain-mode output dict. The original
# bound method is kept so it can be put back later.
# TODO(vinnamkim): This will be revisited by the export refactoring
# NOTE: the double-underscore attribute is name-mangled by Python with the
# enclosing class name, so it is private to that class.
self.__orig_model_forward = self.model.forward
self.model.forward = self.model._forward_explain # type: ignore[assignment] # noqa: SLF001

if len(x.shape) == 4 and not self.model.use_layer_norm_2d:
x = x.view(x.size(0), -1)
return head(x)
def _restore_model_forward(self) -> None:
# Puts back the ``forward`` that was saved in ``self.__orig_model_forward``.
# NOTE(review): the only visible assignment of that attribute is in
# _reset_model_forward, so this presumably must be called after it - confirm.
# TODO(vinnamkim): This will be revisited by the export refactoring
self.model.forward = self.__orig_model_forward # type: ignore[method-assign]
Loading

0 comments on commit acf03b9

Please sign in to comment.