Skip to content

Commit

Permalink
Enable export of feature vectors for semantic segmentation task (#4055)
Browse files Browse the repository at this point in the history
  • Loading branch information
kprokofi authored Oct 23, 2024
1 parent a837a1d commit decfdbe
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 29 deletions.
4 changes: 4 additions & 0 deletions src/otx/algo/segmentation/huggingface_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,8 @@ def _exporter(self) -> OTXModelExporter:

def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | dict[str, torch.Tensor]:
    """Run the wrapped model on ``image`` while the model is traced for export.

    Args:
        image: Input that is handed unchanged to ``self.model``.

    Returns:
        Whatever ``self.model(image)`` returns.

    Raises:
        NotImplementedError: If ``self.explain_mode`` is truthy.
    """
    # Regular path first: without explain mode, tracing is a plain forward call.
    if not self.explain_mode:
        return self.model(image)

    msg = "Explain mode is not supported for this model."
    raise NotImplementedError(msg)
2 changes: 1 addition & 1 deletion src/otx/algo/segmentation/litehrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def _exporter(self) -> OTXModelExporter:
swap_rgb=False,
via_onnx=False,
onnx_export_configuration={"operator_export_type": OperatorExportTypes.ONNX_ATEN_FALLBACK},
output_names=None,
output_names=["preds", "feature_vector"] if self.explain_mode else None,
)

@property
Expand Down
15 changes: 12 additions & 3 deletions src/otx/algo/segmentation/segmentors/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
import torch.nn.functional as f
from torch import Tensor, nn

from otx.algo.explain.explain_algo import feature_vector_fn

if TYPE_CHECKING:
from otx.core.data.entity.base import ImageInfo

Expand Down Expand Up @@ -58,7 +60,7 @@ def forward(
- If mode is "predict", returns the predicted outputs.
- Otherwise, returns the model outputs after interpolation.
"""
outputs = self.extract_features(inputs)
enc_feats, outputs = self.extract_features(inputs)
outputs = f.interpolate(outputs, size=inputs.size()[2:], mode="bilinear", align_corners=True)

if mode == "tensor":
Expand All @@ -76,12 +78,19 @@ def forward(
if mode == "predict":
return outputs.argmax(dim=1)

if mode == "explain":
feature_vector = feature_vector_fn(enc_feats)
return {
"preds": outputs,
"feature_vector": feature_vector,
}

return outputs

def extract_features(self, inputs: Tensor) -> Tensor:
def extract_features(self, inputs: Tensor) -> tuple[Tensor, Tensor]:
"""Extract features from the backbone and head."""
enc_feats = self.backbone(inputs)
return self.decode_head(enc_feats)
return enc_feats, self.decode_head(enc_feats)

def calculate_loss(
self,
Expand Down
1 change: 0 additions & 1 deletion src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,6 @@ def __init__(
self.input_size = input_size
self.classification_layers: dict[str, dict[str, Any]] = {}
self.model = self._create_model()
self._explain_mode = False
self.optimizer_callable = ensure_callable(optimizer)
self.scheduler_callable = ensure_callable(scheduler)
self.metric_callable = ensure_callable(metric)
Expand Down
81 changes: 63 additions & 18 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,12 @@ def _build_model(self) -> nn.Module:
"""

def _customize_inputs(self, entity: SegBatchDataEntity) -> dict[str, Any]:
mode = "loss" if self.training else "predict"
if self.training:
mode = "loss"
elif self.explain_mode:
mode = "explain"
else:
mode = "predict"

if self.train_type == OTXTrainType.SEMI_SUPERVISED and mode == "loss":
if not isinstance(entity, dict):
Expand Down Expand Up @@ -155,6 +160,16 @@ def _customize_outputs(
losses[k] = v
return losses

if self.explain_mode:
return SegBatchPredEntity(
batch_size=len(outputs["preds"]),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=outputs["preds"],
feature_vector=outputs["feature_vector"],
)

return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
Expand Down Expand Up @@ -199,14 +214,24 @@ def _exporter(self) -> OTXModelExporter:
swap_rgb=False,
via_onnx=False,
onnx_export_configuration=None,
output_names=None,
output_names=["preds", "feature_vector"] if self.explain_mode else None,
)

def _convert_pred_entity_to_compute_metric(
self,
preds: SegBatchPredEntity,
inputs: SegBatchDataEntity,
) -> MetricInput:
"""Convert prediction and input entities to a format suitable for metric computation.
Args:
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
Returns:
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
corresponding to the predicted and target masks for metric evaluation.
"""
return [
{
"preds": pred_mask,
Expand All @@ -228,8 +253,26 @@ def _dispatch_label_info(label_info: LabelInfoTypes) -> LabelInfo:

def forward_for_tracing(self, image: Tensor) -> Tensor | dict[str, Tensor]:
"""Model forward function used for the model tracing during model exportation."""
raw_outputs = self.model(inputs=image, mode="tensor")
return torch.softmax(raw_outputs, dim=1)
if self.explain_mode:
outputs = self.model(inputs=image, mode="explain")
outputs["preds"] = torch.softmax(outputs["preds"], dim=1)
return outputs

outputs = self.model(inputs=image, mode="tensor")
return torch.softmax(outputs, dim=1)

def forward_explain(self, inputs: SegBatchDataEntity) -> SegBatchPredEntity:
    """Run the model in ``"explain"`` mode and wrap the result in a prediction entity.

    Args:
        inputs: Batch whose ``images`` are passed to the model; its ``images``
            and ``imgs_info`` are also carried over into the returned entity.

    Returns:
        A ``SegBatchPredEntity`` whose ``masks`` is the model's ``"preds"``
        output, whose ``feature_vector`` is the model's ``"feature_vector"``
        output, and whose ``scores`` is an empty list.
    """
    explain_out = self.model(inputs=inputs.images, mode="explain")
    preds = explain_out["preds"]

    return SegBatchPredEntity(
        batch_size=len(preds),
        images=inputs.images,
        imgs_info=inputs.imgs_info,
        scores=[],
        masks=preds,
        feature_vector=explain_out["feature_vector"],
    )

def get_dummy_input(self, batch_size: int = 1) -> SegBatchDataEntity:
"""Returns a dummy input for semantic segmentation model."""
Expand Down Expand Up @@ -308,32 +351,34 @@ def _customize_outputs(
outputs: list[ImageResultWithSoftPrediction],
inputs: SegBatchDataEntity,
) -> SegBatchPredEntity | OTXBatchLossEntity:
if outputs and outputs[0].saliency_map.size != 1:
predicted_s_maps = [out.saliency_map for out in outputs]
predicted_f_vectors = [out.feature_vector for out in outputs]
return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
saliency_map=predicted_s_maps,
feature_vector=predicted_f_vectors,
)

masks = [tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs]
predicted_f_vectors = (
[out.feature_vector for out in outputs] if outputs and outputs[0].feature_vector.size != 1 else []
)
return SegBatchPredEntity(
batch_size=len(outputs),
images=inputs.images,
imgs_info=inputs.imgs_info,
scores=[],
masks=[tv_tensors.Mask(mask.resultImage, device=self.device) for mask in outputs],
masks=masks,
feature_vector=predicted_f_vectors,
)

def _convert_pred_entity_to_compute_metric(
self,
preds: SegBatchPredEntity,
inputs: SegBatchDataEntity,
) -> MetricInput:
"""Convert prediction and input entities to a format suitable for metric computation.
Args:
preds (SegBatchPredEntity): The predicted segmentation batch entity containing predicted masks.
inputs (SegBatchDataEntity): The input segmentation batch entity containing ground truth masks.
Returns:
MetricInput: A list of dictionaries where each dictionary contains 'preds' and 'target' keys
corresponding to the predicted and target masks for metric evaluation.
"""
return [
{
"preds": pred_mask,
Expand Down
4 changes: 2 additions & 2 deletions tests/e2e/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ def test_otx_e2e_cli(
# 5) otx export with XAI
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
return
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
return # Supported only for classification, detection and segmentation tasks.

unsupported_models = ["dino", "rtdetr"]
if any(model in model_name for model in unsupported_models):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,8 +242,8 @@ def test_otx_e2e(
# 5) otx export with XAI
if "instance_segmentation/rtmdet_inst_tiny" in recipe:
return
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation"]):
return # Supported only for classification, detection and instance segmentation task.
if ("_cls" not in task) and (task not in ["detection", "instance_segmentation", "semantic_segmentation"]):
return # Supported only for classification, detection and segmentation tasks.

if "dino" in model_name:
return # DINO is not supported.
Expand Down
6 changes: 4 additions & 2 deletions tests/unit/algo/segmentation/segmentors/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ def test_forward_returns_prediction(self, model, inputs):
def test_extract_features(self, model, inputs):
images = inputs[0]
features = model.extract_features(images)
assert isinstance(features, torch.Tensor)
assert features.shape == (1, 2, 256, 256)
assert isinstance(features, tuple)
assert isinstance(features[0], torch.Tensor)
assert isinstance(features[1], torch.Tensor)
assert features[1].shape == (1, 2, 256, 256)

def test_calculate_loss(self, model, inputs):
model.criterion.name = "CrossEntropyLoss"
Expand Down

0 comments on commit decfdbe

Please sign in to comment.