Skip to content

Commit

Permalink
Refactor tiling data typing (#3331)
Browse files Browse the repository at this point in the history
* Refactor

* Fix

Signed-off-by: Kim, Vinnam <[email protected]>

---------

Signed-off-by: Kim, Vinnam <[email protected]>
  • Loading branch information
vinnamkim authored Apr 16, 2024
1 parent e687cd1 commit 6cce900
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 84 deletions.
19 changes: 8 additions & 11 deletions src/otx/core/data/entity/tile.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Sequence, TypeVar
from typing import TYPE_CHECKING, Generic, Sequence

from otx.core.types.task import OTXTaskType

Expand All @@ -20,12 +20,6 @@
from torchvision import tv_tensors


T_OTXTileBatchDataEntity = TypeVar(
"T_OTXTileBatchDataEntity",
bound="OTXTileBatchDataEntity",
)


@dataclass
class TileDataEntity(Generic[T_OTXDataEntity]):
"""Base data entity for tile task.
Expand Down Expand Up @@ -66,6 +60,9 @@ def task(self) -> OTXTaskType:
return OTXTaskType.DETECTION


TileAttrDictList = list[dict[str, int | str]]


@dataclass
class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]):
"""Base batch data entity for tile task.
Expand All @@ -82,10 +79,10 @@ class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]):
batch_size: int
batch_tiles: list[list[tv_tensors.Image]]
batch_tile_img_infos: list[list[ImageInfo]]
batch_tile_attr_list: list[list[dict[str, int | str]]]
batch_tile_attr_list: list[TileAttrDictList]
imgs_info: list[ImageInfo]

def unbind(self) -> list[T_OTXBatchDataEntity]:
def unbind(self) -> list[tuple[TileAttrDictList, T_OTXBatchDataEntity]]:
"""Unbind batch data entity."""
raise NotImplementedError

Expand All @@ -102,7 +99,7 @@ class TileBatchDetDataEntity(OTXTileBatchDataEntity):
bboxes: list[tv_tensors.BoundingBoxes]
labels: list[LongTensor]

def unbind(self) -> list[tuple[list[dict[str, int | str]], DetBatchDataEntity]]:
def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]:
"""Unbind batch data entity for detection task."""
tiles = [tile for tiles in self.batch_tiles for tile in tiles]
tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos]
Expand Down Expand Up @@ -194,7 +191,7 @@ class TileBatchInstSegDataEntity(OTXTileBatchDataEntity):
masks: list[tv_tensors.Mask]
polygons: list[list[Polygon]]

def unbind(self) -> list[tuple[list[dict[str, int | str]], InstanceSegBatchDataEntity]]:
def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]:
"""Unbind batch data entity for instance segmentation task."""
tiles = [tile for tiles in self.batch_tiles for tile in tiles]
tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos]
Expand Down
9 changes: 1 addition & 8 deletions src/otx/core/model/action_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsBatchPredEntity
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.tile import T_OTXTileBatchDataEntity
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics import MetricInput
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
Expand All @@ -32,13 +31,7 @@
from otx.core.metrics import MetricCallable


class OTXActionClsModel(
OTXModel[
ActionClsBatchDataEntity,
ActionClsBatchPredEntity,
T_OTXTileBatchDataEntity,
],
):
class OTXActionClsModel(OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity]):
"""Base class for the action classification models used in OTX."""

def __init__(
Expand Down
9 changes: 1 addition & 8 deletions src/otx/core/model/action_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.tile import T_OTXTileBatchDataEntity
from otx.core.metrics import MetricInput
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel
Expand All @@ -26,13 +25,7 @@
from otx.core.metrics import MetricCallable


class OTXActionDetModel(
OTXModel[
ActionDetBatchDataEntity,
ActionDetBatchPredEntity,
T_OTXTileBatchDataEntity,
],
):
class OTXActionDetModel(OTXModel[ActionDetBatchDataEntity, ActionDetBatchPredEntity]):
"""Base class for the action detection models used in OTX."""

def __init__(
Expand Down
6 changes: 3 additions & 3 deletions src/otx/core/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
T_OTXBatchDataEntity,
T_OTXBatchPredEntity,
)
from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics import MetricInput, NullMetricCallable
Expand Down Expand Up @@ -81,7 +81,7 @@ def _default_scheduler_callable(
DefaultSchedulerCallable = _default_scheduler_callable


class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity]):
class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]):
"""Base class for the models used in OTX.
Args:
Expand Down Expand Up @@ -515,7 +515,7 @@ def _restore_model_forward(self) -> None:

def forward_tiles(
self,
inputs: T_OTXTileBatchDataEntity,
inputs: OTXTileBatchDataEntity[T_OTXBatchDataEntity],
) -> T_OTXBatchPredEntity | OTXBatchLossEntity:
"""Model forward function for tile task."""
raise NotImplementedError
Expand Down
25 changes: 3 additions & 22 deletions src/otx/core/model/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
MultilabelClsBatchDataEntity,
MultilabelClsBatchPredEntity,
)
from otx.core.data.entity.tile import T_OTXTileBatchDataEntity
from otx.core.exporter.base import OTXModelExporter
from otx.core.exporter.native import OTXNativeModelExporter
from otx.core.metrics import MetricInput
Expand All @@ -46,13 +45,7 @@
from otx.core.metrics import MetricCallable


class OTXMulticlassClsModel(
OTXModel[
MulticlassClsBatchDataEntity,
MulticlassClsBatchPredEntity,
T_OTXTileBatchDataEntity,
],
):
class OTXMulticlassClsModel(OTXModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity]):
"""Base class for the classification models used in OTX."""

def __init__(
Expand Down Expand Up @@ -241,13 +234,7 @@ def _exporter(self) -> OTXModelExporter:
### It'll be integrated after H-label classification integration with more advanced design.


class OTXMultilabelClsModel(
OTXModel[
MultilabelClsBatchDataEntity,
MultilabelClsBatchPredEntity,
T_OTXTileBatchDataEntity,
],
):
class OTXMultilabelClsModel(OTXModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]):
"""Multi-label classification models used in OTX."""

def __init__(
Expand Down Expand Up @@ -433,13 +420,7 @@ def _exporter(self) -> OTXModelExporter:
)


class OTXHlabelClsModel(
OTXModel[
HlabelClsBatchDataEntity,
HlabelClsBatchPredEntity,
T_OTXTileBatchDataEntity,
],
):
class OTXHlabelClsModel(OTXModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]):
"""H-label classification models used in OTX."""

def __init__(
Expand Down
11 changes: 4 additions & 7 deletions src/otx/core/model/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchDetDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.mean_ap import MeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
Expand All @@ -37,7 +37,7 @@
from torchmetrics import Metric


class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]):
class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity]):
"""Base class for the detection models used in OTX."""

def __init__(
Expand All @@ -57,7 +57,7 @@ def __init__(
)
self._tile_config = TileConfig()

def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity:
def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> DetBatchPredEntity:
"""Unpack detection tiles.
Args:
Expand Down Expand Up @@ -193,10 +193,7 @@ def __init__(
self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

def forward_explain(
self,
inputs: DetBatchDataEntity | TileBatchDetDataEntity,
) -> DetBatchPredEntity:
def forward_explain(self, inputs: DetBatchDataEntity) -> DetBatchPredEntity:
"""Model forward function."""
from otx.algo.explain.explain_algo import get_feature_vector

Expand Down
19 changes: 5 additions & 14 deletions src/otx/core/model/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from otx.core.config.data import TileConfig
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchInstSegDataEntity
from otx.core.data.entity.tile import OTXTileBatchDataEntity
from otx.core.metrics import MetricInput
from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
Expand All @@ -33,7 +33,7 @@
if TYPE_CHECKING:
from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable
from mmdet.models.data_preprocessors import DetDataPreprocessor
from mmdet.models.detectors.base import TwoStageDetector
from mmdet.models.detectors.two_stage import TwoStageDetector
from mmdet.structures import OptSampleList
from omegaconf import DictConfig
from openvino.model_api.models.utils import InstanceSegmentationResult
Expand All @@ -43,13 +43,7 @@
from otx.core.metrics import MetricCallable


class OTXInstanceSegModel(
OTXModel[
InstanceSegBatchDataEntity,
InstanceSegBatchPredEntity,
TileBatchInstSegDataEntity,
],
):
class OTXInstanceSegModel(OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]):
"""Base class for the Instance Segmentation models used in OTX."""

def __init__(
Expand All @@ -69,7 +63,7 @@ def __init__(
)
self._tile_config = TileConfig()

def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity:
def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) -> InstanceSegBatchPredEntity:
"""Unpack instance segmentation tiles.
Args:
Expand Down Expand Up @@ -235,10 +229,7 @@ def __init__(
self.model.feature_vector_fn = get_feature_vector
self.model.explain_fn = self.get_explain_fn()

def forward_explain(
self,
inputs: InstanceSegBatchDataEntity | TileBatchInstSegDataEntity,
) -> InstanceSegBatchPredEntity:
def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity:
"""Model forward function."""
if isinstance(inputs, OTXTileBatchDataEntity):
return self.forward_tiles(inputs)
Expand Down
3 changes: 1 addition & 2 deletions src/otx/core/model/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity
from otx.core.data.entity.tile import T_OTXTileBatchDataEntity
from otx.core.metrics import MetricInput
from otx.core.metrics.dice import SegmCallable
from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel
Expand All @@ -31,7 +30,7 @@
from otx.core.metrics import MetricCallable


class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity, T_OTXTileBatchDataEntity]):
class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity]):
"""Base class for the detection models used in OTX."""

def __init__(
Expand Down
11 changes: 2 additions & 9 deletions src/otx/core/model/visual_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from torchvision import tv_tensors

from otx.core.data.entity.base import OTXBatchLossEntity, Points
from otx.core.data.entity.tile import T_OTXTileBatchDataEntity
from otx.core.data.entity.visual_prompting import (
VisualPromptingBatchDataEntity,
VisualPromptingBatchPredEntity,
Expand Down Expand Up @@ -169,9 +168,7 @@ def _inference_step_for_zero_shot(
)


class OTXVisualPromptingModel(
OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, T_OTXTileBatchDataEntity],
):
class OTXVisualPromptingModel(OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity]):
"""Base class for the visual prompting models used in OTX."""

def __init__(
Expand Down Expand Up @@ -276,11 +273,7 @@ def _set_label_info(self, _: LabelInfo | list[str]) -> None:


class OTXZeroShotVisualPromptingModel(
OTXModel[
ZeroShotVisualPromptingBatchDataEntity,
ZeroShotVisualPromptingBatchPredEntity,
T_OTXTileBatchDataEntity,
],
OTXModel[ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity],
):
"""Base class for the visual prompting models used in OTX."""

Expand Down

0 comments on commit 6cce900

Please sign in to comment.