Skip to content

Commit

Permalink
Add Keypoint Detection legacy template (#4094)
Browse files Browse the repository at this point in the history
added rtmpose_template
  • Loading branch information
kprokofi authored Nov 5, 2024
1 parent dc882bf commit 15746ea
Show file tree
Hide file tree
Showing 4 changed files with 541 additions and 6 deletions.
36 changes: 35 additions & 1 deletion src/otx/core/model/keypoint_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import torch

from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.core.data.entity.base import OTXBatchLossEntity
from otx.core.data.entity.base import ImageInfo, OTXBatchLossEntity
from otx.core.data.entity.keypoint_detection import KeypointDetBatchDataEntity, KeypointDetBatchPredEntity
from otx.core.metrics import MetricCallable, MetricInput
from otx.core.metrics.pck import PCKMeasureCallable
Expand Down Expand Up @@ -150,6 +150,40 @@ def forward_for_tracing(self, image: torch.Tensor) -> torch.Tensor | tuple[torch
"""Model forward function used for the model tracing during model exportation."""
return self.model.forward(inputs=image, mode="tensor")

def get_dummy_input(self, batch_size: int = 1) -> KeypointDetBatchDataEntity:
"""Generates a dummy input, suitable for launching forward() on it.
Args:
batch_size (int, optional): number of elements in a dummy input sequence. Defaults to 1.
Returns:
KeypointDetBatchDataEntity: An entity containing randomly generated inference data.
"""
if self.input_size is None:
msg = f"Input size attribute is not set for {self.__class__}"
raise ValueError(msg)

images = torch.rand(batch_size, 3, *self.input_size)
infos = []
for i, img in enumerate(images):
infos.append(
ImageInfo(
img_idx=i,
img_shape=img.shape,
ori_shape=img.shape,
),
)
return KeypointDetBatchDataEntity(
batch_size,
images,
infos,
bboxes=[],
labels=[],
bbox_info=[],
keypoints=[],
keypoints_visible=[],
)

@property
def _export_parameters(self) -> TaskLevelExportParameters:
"""Defines parameters required to export a particular model implementation."""
Expand Down
6 changes: 1 addition & 5 deletions src/otx/tools/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,10 @@
"model_name": "stfpm",
},
# KEYPOINT_DETECTION
"Custom_Keypoint_Detection_Rtmpose_T": {
"Keypoint_Detection_RTMPose_Tiny": {
"task": OTXTaskType.KEYPOINT_DETECTION,
"model_name": "rtmpose_tiny",
},
"Custom_Keypoint_Detection_Rtmpose_T_Single_Obj": {
"task": OTXTaskType.KEYPOINT_DETECTION,
"model_name": "rtmpose_tiny_single_obj",
},
}


Expand Down
Loading

0 comments on commit 15746ea

Please sign in to comment.