Skip to content

Commit

Permalink
XAI for tiling: detection, instance segmentation (#3297)
Browse files Browse the repository at this point in the history
* XAI tiling

* Add tests

* Fix pre-commit

* Omit extra changes

* Refactor DetBatchPredEntity initialization

* Initialization update

* Fix comments & integration tests

* Update doc description

* Fix merge conflicts

* Minor
  • Loading branch information
GalyaZalesskaya authored Apr 15, 2024
1 parent 44fba13 commit 61095d6
Show file tree
Hide file tree
Showing 7 changed files with 413 additions and 160 deletions.
9 changes: 6 additions & 3 deletions src/otx/algo/explain/explain_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,19 +341,22 @@ def func(
@classmethod
def average_and_normalize(
cls,
pred: InstanceData,
pred: InstanceData | dict[str, torch.Tensor],
num_classes: int,
) -> np.array:
"""Average and normalize masks in prediction per-class.
Args:
preds (InstanceData): Predictions of Instance Segmentation model.
preds (InstanceData | dict): Predictions of Instance Segmentation model.
num_classes (int): Num classes that model can predict.
Returns:
np.array: Class-wise Saliency Maps. One saliency map per each class - [class_id, H, W]
"""
masks, scores, labels = (pred.masks, pred.scores, pred.labels)
if isinstance(pred, dict):
masks, scores, labels = pred["masks"], pred["scores"], pred["labels"]
else:
masks, scores, labels = (pred.masks, pred.scores, pred.labels)
_, height, width = masks.shape

saliency_map = torch.zeros((num_classes, height, width), dtype=torch.float32, device=labels.device)
Expand Down
25 changes: 19 additions & 6 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.data.entity.tile import TileBatchDetDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchDetDataEntity
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
Expand Down Expand Up @@ -70,26 +70,31 @@ def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity:
tile_attrs: list[list[dict[str, int | str]]] = []
merger = DetectionTileMerge(
inputs.imgs_info,
self.tile_config.iou_threshold,
self.tile_config.max_num_instances,
self.num_classes,
self.tile_config,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward(batch_tile_input)
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
if isinstance(output, OTXBatchLossEntity):
msg = "Loss output is not supported for tile merging"
raise TypeError(msg)
tile_preds.append(output)
tile_attrs.append(batch_tile_attrs)
pred_entities = merger.merge(tile_preds, tile_attrs)

return DetBatchPredEntity(
pred_entity = DetBatchPredEntity(
batch_size=inputs.batch_size,
images=[pred_entity.image for pred_entity in pred_entities],
imgs_info=[pred_entity.img_info for pred_entity in pred_entities],
scores=[pred_entity.score for pred_entity in pred_entities],
bboxes=[pred_entity.bboxes for pred_entity in pred_entities],
labels=[pred_entity.labels for pred_entity in pred_entities],
)
if self.explain_mode:
pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities]
pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities]

return pred_entity

@property
def _export_parameters(self) -> TaskLevelExportParameters:
Expand Down Expand Up @@ -190,9 +195,17 @@ def __init__(

def forward_explain(
self,
inputs: DetBatchDataEntity,
inputs: DetBatchDataEntity | TileBatchDetDataEntity,
) -> DetBatchPredEntity:
"""Model forward function."""
from otx.algo.explain.explain_algo import get_feature_vector

if isinstance(inputs, OTXTileBatchDataEntity):
return self.forward_tiles(inputs)

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

# If customize_inputs is overridden
outputs = (
self._forward_explain_detection(self.model, **self._customize_inputs(inputs))
Expand Down
24 changes: 18 additions & 6 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
from openvino.model_api.tilers import InstanceSegmentationTiler
from torchvision import tv_tensors

from otx.algo.explain.explain_algo import get_feature_vector
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
from otx.core.data.entity.tile import TileBatchInstSegDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchInstSegDataEntity
from otx.core.metrics import MetricInput
from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
Expand Down Expand Up @@ -81,19 +82,19 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP
tile_attrs: list[list[dict[str, int | str]]] = []
merger = InstanceSegTileMerge(
inputs.imgs_info,
self.tile_config.iou_threshold,
self.tile_config.max_num_instances,
self.num_classes,
self.tile_config,
)
for batch_tile_attrs, batch_tile_input in inputs.unbind():
output = self.forward(batch_tile_input)
output = self.forward_explain(batch_tile_input) if self.explain_mode else self.forward(batch_tile_input)
if isinstance(output, OTXBatchLossEntity):
msg = "Loss output is not supported for tile merging"
raise TypeError(msg)
tile_preds.append(output)
tile_attrs.append(batch_tile_attrs)
pred_entities = merger.merge(tile_preds, tile_attrs)

return InstanceSegBatchPredEntity(
pred_entity = InstanceSegBatchPredEntity(
batch_size=inputs.batch_size,
images=[pred_entity.image for pred_entity in pred_entities],
imgs_info=[pred_entity.img_info for pred_entity in pred_entities],
Expand All @@ -103,6 +104,11 @@ def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchP
masks=[pred_entity.masks for pred_entity in pred_entities],
polygons=[pred_entity.polygons for pred_entity in pred_entities],
)
if self.explain_mode:
pred_entity.saliency_map = [pred_entity.saliency_map for pred_entity in pred_entities]
pred_entity.feature_vector = [pred_entity.feature_vector for pred_entity in pred_entities]

return pred_entity

@property
def _export_parameters(self) -> TaskLevelExportParameters:
Expand Down Expand Up @@ -231,9 +237,15 @@ def __init__(

def forward_explain(
self,
inputs: InstanceSegBatchDataEntity,
inputs: InstanceSegBatchDataEntity | TileBatchInstSegDataEntity,
) -> InstanceSegBatchPredEntity:
"""Model forward function."""
if isinstance(inputs, OTXTileBatchDataEntity):
return self.forward_tiles(inputs)

self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

# If customize_inputs is overridden
outputs = (
self._forward_explain_inst_seg(self.model, **self._customize_inputs(inputs))
Expand Down
Loading

0 comments on commit 61095d6

Please sign in to comment.