Skip to content

Commit

Permalink
chore: remove duplicated SegmentationData2D (#24)
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 authored Nov 1, 2024
1 parent 4bde9e8 commit 042d7bf
Showing 1 changed file with 1 addition and 37 deletions.
38 changes: 1 addition & 37 deletions t4_devkit/viewer/rendering_data/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

if TYPE_CHECKING:
from t4_devkit.dataclass import Box2D, Box3D
from t4_devkit.typing import NDArrayU8, RoiType, SizeType, TranslationType, VelocityType
from t4_devkit.typing import RoiType, SizeType, TranslationType, VelocityType

__all__ = ["BoxData3D", "BoxData2D"]

Expand Down Expand Up @@ -107,39 +107,3 @@ def as_boxes2d(self) -> rr.Boxes2D:
labels=labels,
class_ids=self._class_ids,
)


class Segmentation2D:
def __init__(self) -> None:
self._masks: list[NDArrayU8] = []
self._class_ids: list[int] = []

self._size: tuple[int, int] = None # (height, width)

def append(self, mask: NDArrayU8, class_id: int) -> None:
"""Append a segmentation mask and its class ID.
Args:
mask (NDArrayU8): Mask image, in the shape of (H, W).
class_id (int): Class ID.
Raises:
ValueError: _description_
"""
if self._size is None:
self._size = mask.shape
else:
if self._size != mask.shape:
raise ValueError(
f"All masks must be the same size. Expected: {self._size}, but got {mask.shape}"
)
self._masks.append(mask)
self._class_ids.append(class_id)

def as_segmentation_image(self) -> rr.SegmentationImage:
image = np.zeros(self._size, dtype=np.uint8)

for mask, class_id in zip(self._masks, self._class_ids, strict=True):
image[mask == 1] == class_id

return rr.SegmentationImage(data=image)

0 comments on commit 042d7bf

Please sign in to comment.