diff --git a/src/otx/core/data/entity/tile.py b/src/otx/core/data/entity/tile.py index e407328ed53..911cd8b70db 100644 --- a/src/otx/core/data/entity/tile.py +++ b/src/otx/core/data/entity/tile.py @@ -6,7 +6,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Generic, Sequence, TypeVar +from typing import TYPE_CHECKING, Generic, Sequence from otx.core.types.task import OTXTaskType @@ -20,12 +20,6 @@ from torchvision import tv_tensors -T_OTXTileBatchDataEntity = TypeVar( - "T_OTXTileBatchDataEntity", - bound="OTXTileBatchDataEntity", -) - - @dataclass class TileDataEntity(Generic[T_OTXDataEntity]): """Base data entity for tile task. @@ -66,6 +60,9 @@ def task(self) -> OTXTaskType: return OTXTaskType.DETECTION +TileAttrDictList = list[dict[str, int | str]] + + @dataclass class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]): """Base batch data entity for tile task. @@ -82,10 +79,10 @@ class OTXTileBatchDataEntity(Generic[T_OTXBatchDataEntity]): batch_size: int batch_tiles: list[list[tv_tensors.Image]] batch_tile_img_infos: list[list[ImageInfo]] - batch_tile_attr_list: list[list[dict[str, int | str]]] + batch_tile_attr_list: list[TileAttrDictList] imgs_info: list[ImageInfo] - def unbind(self) -> list[T_OTXBatchDataEntity]: + def unbind(self) -> list[tuple[TileAttrDictList, T_OTXBatchDataEntity]]: """Unbind batch data entity.""" raise NotImplementedError @@ -102,7 +99,7 @@ class TileBatchDetDataEntity(OTXTileBatchDataEntity): bboxes: list[tv_tensors.BoundingBoxes] labels: list[LongTensor] - def unbind(self) -> list[tuple[list[dict[str, int | str]], DetBatchDataEntity]]: + def unbind(self) -> list[tuple[TileAttrDictList, DetBatchDataEntity]]: """Unbind batch data entity for detection task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] @@ -194,7 +191,7 @@ class TileBatchInstSegDataEntity(OTXTileBatchDataEntity): masks: list[tv_tensors.Mask] polygons: list[list[Polygon]] - def unbind(self) -> list[tuple[list[dict[str, int | str]], InstanceSegBatchDataEntity]]: + def unbind(self) -> list[tuple[TileAttrDictList, InstanceSegBatchDataEntity]]: """Unbind batch data entity for instance segmentation task.""" tiles = [tile for tiles in self.batch_tiles for tile in tiles] tile_infos = [tile_info for tile_infos in self.batch_tile_img_infos for tile_info in tile_infos] diff --git a/src/otx/core/model/action_classification.py b/src/otx/core/model/action_classification.py index 0c662f0fea0..28b37a50061 100644 --- a/src/otx/core/model/action_classification.py +++ b/src/otx/core/model/action_classification.py @@ -12,7 +12,6 @@ from otx.core.data.entity.action_classification import ActionClsBatchDataEntity, ActionClsBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput from otx.core.metrics.accuracy import MultiClassClsMetricCallable @@ -32,13 +31,7 @@ from otx.core.metrics import MetricCallable -class OTXActionClsModel( - OTXModel[ - ActionClsBatchDataEntity, - ActionClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXActionClsModel(OTXModel[ActionClsBatchDataEntity, ActionClsBatchPredEntity]): """Base class for the action classification models used in OTX.""" def __init__( diff --git a/src/otx/core/model/action_detection.py b/src/otx/core/model/action_detection.py index 7529f93602c..b978110ba50 100644 --- a/src/otx/core/model/action_detection.py +++ b/src/otx/core/model/action_detection.py @@ -11,7 +11,6 @@ from otx.core.data.entity.action_detection import ActionDetBatchDataEntity, ActionDetBatchPredEntity from otx.core.data.entity.base import OTXBatchLossEntity -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.metrics import MetricInput from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel @@ -26,13 +25,7 @@ from otx.core.metrics import MetricCallable -class OTXActionDetModel( - OTXModel[ - ActionDetBatchDataEntity, - ActionDetBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXActionDetModel(OTXModel[ActionDetBatchDataEntity, ActionDetBatchPredEntity]): """Base class for the action detection models used in OTX.""" def __init__( diff --git a/src/otx/core/model/base.py b/src/otx/core/model/base.py index a4bf4943e0e..aa0c56fbbb1 100644 --- a/src/otx/core/model/base.py +++ b/src/otx/core/model/base.py @@ -34,7 +34,7 @@ T_OTXBatchDataEntity, T_OTXBatchPredEntity, ) -from otx.core.data.entity.tile import OTXTileBatchDataEntity, T_OTXTileBatchDataEntity +from otx.core.data.entity.tile import OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput, NullMetricCallable @@ -81,7 +81,7 @@ def _default_scheduler_callable( DefaultSchedulerCallable = _default_scheduler_callable -class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXModel(LightningModule, Generic[T_OTXBatchDataEntity, T_OTXBatchPredEntity]): """Base class for the models used in OTX. Args: @@ -515,7 +515,7 @@ def _restore_model_forward(self) -> None: def forward_tiles( self, - inputs: T_OTXTileBatchDataEntity, + inputs: OTXTileBatchDataEntity[T_OTXBatchDataEntity], ) -> T_OTXBatchPredEntity | OTXBatchLossEntity: """Model forward function for tile task.""" raise NotImplementedError diff --git a/src/otx/core/model/classification.py b/src/otx/core/model/classification.py index a88c2252511..d724d76ff16 100644 --- a/src/otx/core/model/classification.py +++ b/src/otx/core/model/classification.py @@ -20,7 +20,6 @@ MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity, ) -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter from otx.core.metrics import MetricInput @@ -46,13 +45,7 @@ from otx.core.metrics import MetricCallable -class OTXMulticlassClsModel( - OTXModel[ - MulticlassClsBatchDataEntity, - MulticlassClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXMulticlassClsModel(OTXModel[MulticlassClsBatchDataEntity, MulticlassClsBatchPredEntity]): """Base class for the classification models used in OTX.""" def __init__( @@ -241,13 +234,7 @@ def _exporter(self) -> OTXModelExporter: ### It'll be integrated after H-label classification integration with more advanced design. -class OTXMultilabelClsModel( - OTXModel[ - MultilabelClsBatchDataEntity, - MultilabelClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXMultilabelClsModel(OTXModel[MultilabelClsBatchDataEntity, MultilabelClsBatchPredEntity]): """Multi-label classification models used in OTX.""" def __init__( @@ -433,13 +420,7 @@ def _exporter(self) -> OTXModelExporter: ) -class OTXHlabelClsModel( - OTXModel[ - HlabelClsBatchDataEntity, - HlabelClsBatchPredEntity, - T_OTXTileBatchDataEntity, - ], -): +class OTXHlabelClsModel(OTXModel[HlabelClsBatchDataEntity, HlabelClsBatchPredEntity]): """H-label classification models used in OTX.""" def __init__( diff --git a/src/otx/core/model/detection.py b/src/otx/core/model/detection.py index 08e16c89752..53e594810aa 100644 --- a/src/otx/core/model/detection.py +++ b/src/otx/core/model/detection.py @@ -17,7 +17,7 @@ from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity -from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchDetDataEntity +from otx.core.data.entity.tile import OTXTileBatchDataEntity from otx.core.metrics import MetricCallable, MetricInput from otx.core.metrics.mean_ap import MeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel @@ -37,7 +37,7 @@ from torchmetrics import Metric -class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity, TileBatchDetDataEntity]): +class OTXDetectionModel(OTXModel[DetBatchDataEntity, DetBatchPredEntity]): """Base class for the detection models used in OTX.""" def __init__( @@ -57,7 +57,7 @@ def __init__( ) self._tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchDetDataEntity) -> DetBatchPredEntity: + def forward_tiles(self, inputs: OTXTileBatchDataEntity[DetBatchDataEntity]) -> DetBatchPredEntity: """Unpack detection tiles. Args: @@ -193,10 +193,7 @@ def __init__( self.model.feature_vector_fn = get_feature_vector self.model.explain_fn = self.get_explain_fn() - def forward_explain( - self, - inputs: DetBatchDataEntity | TileBatchDetDataEntity, - ) -> DetBatchPredEntity: + def forward_explain(self, inputs: DetBatchDataEntity) -> DetBatchPredEntity: """Model forward function.""" from otx.algo.explain.explain_algo import get_feature_vector diff --git a/src/otx/core/model/instance_segmentation.py b/src/otx/core/model/instance_segmentation.py index 10386b37f0c..121c6ca585e 100644 --- a/src/otx/core/model/instance_segmentation.py +++ b/src/otx/core/model/instance_segmentation.py @@ -20,7 +20,7 @@ from otx.core.config.data import TileConfig from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity, InstanceSegBatchPredEntity -from otx.core.data.entity.tile import OTXTileBatchDataEntity, TileBatchInstSegDataEntity +from otx.core.data.entity.tile import OTXTileBatchDataEntity from otx.core.metrics import MetricInput from otx.core.metrics.mean_ap import MaskRLEMeanAPCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel @@ -33,7 +33,7 @@ if TYPE_CHECKING: from lightning.pytorch.cli import LRSchedulerCallable, OptimizerCallable from mmdet.models.data_preprocessors import DetDataPreprocessor - from mmdet.models.detectors.base import TwoStageDetector + from mmdet.models.detectors.two_stage import TwoStageDetector from mmdet.structures import OptSampleList from omegaconf import DictConfig from openvino.model_api.models.utils import InstanceSegmentationResult @@ -43,13 +43,7 @@ from otx.core.metrics import MetricCallable -class OTXInstanceSegModel( - OTXModel[ - InstanceSegBatchDataEntity, - InstanceSegBatchPredEntity, - TileBatchInstSegDataEntity, - ], -): +class OTXInstanceSegModel(OTXModel[InstanceSegBatchDataEntity, InstanceSegBatchPredEntity]): """Base class for the Instance Segmentation models used in OTX.""" def __init__( @@ -69,7 +63,7 @@ def __init__( ) self._tile_config = TileConfig() - def forward_tiles(self, inputs: TileBatchInstSegDataEntity) -> InstanceSegBatchPredEntity: + def forward_tiles(self, inputs: OTXTileBatchDataEntity[InstanceSegBatchDataEntity]) -> InstanceSegBatchPredEntity: """Unpack instance segmentation tiles. Args: @@ -235,10 +229,7 @@ def __init__( self.model.feature_vector_fn = get_feature_vector self.model.explain_fn = self.get_explain_fn() - def forward_explain( - self, - inputs: InstanceSegBatchDataEntity | TileBatchInstSegDataEntity, - ) -> InstanceSegBatchPredEntity: + def forward_explain(self, inputs: InstanceSegBatchDataEntity) -> InstanceSegBatchPredEntity: """Model forward function.""" if isinstance(inputs, OTXTileBatchDataEntity): return self.forward_tiles(inputs) diff --git a/src/otx/core/model/segmentation.py b/src/otx/core/model/segmentation.py index 3beb136a2fe..e5a1f140ecf 100644 --- a/src/otx/core/model/segmentation.py +++ b/src/otx/core/model/segmentation.py @@ -12,7 +12,6 @@ from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.segmentation import SegBatchDataEntity, SegBatchPredEntity -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.metrics import MetricInput from otx.core.metrics.dice import SegmCallable from otx.core.model.base import DefaultOptimizerCallable, DefaultSchedulerCallable, OTXModel, OVModel @@ -31,7 +30,7 @@ from otx.core.metrics import MetricCallable -class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity, T_OTXTileBatchDataEntity]): +class OTXSegmentationModel(OTXModel[SegBatchDataEntity, SegBatchPredEntity]): """Base class for the detection models used in OTX.""" def __init__( diff --git a/src/otx/core/model/visual_prompting.py b/src/otx/core/model/visual_prompting.py index 2307474a1ef..081cc916554 100644 --- a/src/otx/core/model/visual_prompting.py +++ b/src/otx/core/model/visual_prompting.py @@ -26,7 +26,6 @@ from torchvision import tv_tensors from otx.core.data.entity.base import OTXBatchLossEntity, Points -from otx.core.data.entity.tile import T_OTXTileBatchDataEntity from otx.core.data.entity.visual_prompting import ( VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, @@ -169,9 +168,7 @@ def _inference_step_for_zero_shot( ) -class OTXVisualPromptingModel( - OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity, T_OTXTileBatchDataEntity], -): +class OTXVisualPromptingModel(OTXModel[VisualPromptingBatchDataEntity, VisualPromptingBatchPredEntity]): """Base class for the visual prompting models used in OTX.""" def __init__( @@ -276,11 +273,7 @@ def _set_label_info(self, _: LabelInfo | list[str]) -> None: class OTXZeroShotVisualPromptingModel( - OTXModel[ - ZeroShotVisualPromptingBatchDataEntity, - ZeroShotVisualPromptingBatchPredEntity, - T_OTXTileBatchDataEntity, - ], + OTXModel[ZeroShotVisualPromptingBatchDataEntity, ZeroShotVisualPromptingBatchPredEntity], ): """Base class for the visual prompting models used in OTX."""