Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support of rendering segmentation image #22

Merged
merged 1 commit into from
Oct 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 44 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ rerun-sdk = "0.17.0"
pyquaternion = "^0.9.9"
matplotlib = "^3.9.2"
shapely = "<2.0.0"
pycocotools = "^2.0.8"

[tool.poetry.group.dev.dependencies]
pytest = "^8.2.2"
Expand Down
57 changes: 51 additions & 6 deletions t4_devkit/schema/tables/object_ann.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,51 @@
from __future__ import annotations

import base64
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from pycocotools import mask as cocomask
from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .registry import SCHEMAS
from ..name import SchemaName

if TYPE_CHECKING:
from t4_devkit.typing import MaskType, RoiType
from t4_devkit.typing import NDArrayU8, RoiType

__all__ = ("ObjectAnn",)
__all__ = ("ObjectAnn", "RLEMask")


@dataclass
class RLEMask:
"""A dataclass to represent segmentation mask compressed by RLE.

Attributes:
size (list[int, int]): Size of image ordering (width, height).
counts (str): RLE compressed mask data.
"""

size: list[int, int]
counts: str

@property
def width(self) -> int:
return self.size[0]

@property
def height(self) -> int:
return self.size[1]

def decode(self) -> NDArrayU8:
"""Decode segmentation mask.

Returns:
Decoded mask in shape of (H, W).
"""
counts = base64.b64decode(self.counts)
data = {"counts": counts, "size": self.size}
return cocomask.decode(data).T


@dataclass
Expand All @@ -27,7 +60,7 @@ class ObjectAnn(SchemaBase):
category_token (str): Foreign key pointing to the object category.
attribute_tokens (list[str]): Foreign keys. List of attributes for this annotation.
bbox (RoiType): Annotated bounding box. Given as [xmin, ymin, xmax, ymax].
mask (MaskType): Instance mask using the COCO format.
mask (RLEMask): Instance mask using the COCO format compressed by RLE.
"""

token: str
Expand All @@ -36,7 +69,7 @@ class ObjectAnn(SchemaBase):
category_token: str
attribute_tokens: list[str]
bbox: RoiType
mask: MaskType
mask: RLEMask

# shortcuts
category_name: str = field(init=False)
Expand All @@ -47,12 +80,24 @@ def shortcuts() -> tuple[str]:

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
return cls(**data)
new_data = data.copy()
new_data["mask"] = RLEMask(**data["mask"])
return cls(**new_data)

@property
def width(self) -> int:
"""Return the width of the bounding box.

Returns:
Bounding box width in pixel.
"""
return self.bbox[2] - self.bbox[0]

@property
def height(self) -> int:
"""Return the height of the bounding box.

Returns:
Bounding box height in pixel.
"""
return self.bbox[3] - self.bbox[1]
11 changes: 7 additions & 4 deletions t4_devkit/schema/tables/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .registry import SCHEMAS
from ..name import SchemaName

__all__ = ("Sample",)

Expand All @@ -30,7 +30,9 @@ class Sample(SchemaBase):
This should be set after instantiated.
ann_3ds (list[str]): List of foreign keys pointing the sample annotations.
This should be set after instantiated.
ann_3ds (list[str]): List of foreign keys pointing the object annotations.
ann_2ds (list[str]): List of foreign keys pointing the object annotations.
This should be set after instantiated.
surface_anns (list[str]): List of foreign keys pointing the surface annotations.
This should be set after instantiated.
"""

Expand All @@ -44,10 +46,11 @@ class Sample(SchemaBase):
data: dict[str, str] = field(default_factory=dict, init=False)
ann_3ds: list[str] = field(default_factory=list, init=False)
ann_2ds: list[str] = field(default_factory=list, init=False)
surface_anns: list[str] = field(default_factory=list, init=False)

@staticmethod
def shortcuts() -> tuple[str, str, str]:
return ("data", "ann_3ds", "ann_2ds")
def shortcuts() -> tuple[str, str, str, str]:
return ("data", "ann_3ds", "ann_2ds", "surface_ann_2ds")

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
Expand Down
36 changes: 30 additions & 6 deletions t4_devkit/schema/tables/surface_ann.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

import numpy as np
from typing_extensions import Self

from ..name import SchemaName
from .base import SchemaBase
from .object_ann import RLEMask
from .registry import SCHEMAS
from ..name import SchemaName

if TYPE_CHECKING:
from t4_devkit.typing import MaskType
from t4_devkit.typing import RoiType

__all__ = ("SurfaceAnn",)

Expand All @@ -24,14 +26,36 @@ class SurfaceAnn(SchemaBase):
token (str): Unique record identifier.
sample_data_token (str): Foreign key pointing to the sample data, which must be a keyframe image.
category_token (str): Foreign key pointing to the surface category.
mask (MaskType): Segmentation mask using the COCO format.
mask (RLEMask): Segmentation mask using the COCO format compressed by RLE.
"""

token: str
sample_data_token: str
category_token: str
mask: MaskType
mask: RLEMask

# shortcuts
category_name: str = field(init=False)

@staticmethod
def shortcuts() -> tuple[str]:
return ("category_name",)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Self:
return cls(**data)
new_data = data.copy()
new_data["mask"] = RLEMask(**data["mask"])
return cls(**new_data)

@property
def bbox(self) -> RoiType:
"""Return a bounding box corners calculated from polygon vertices.

Returns:
Given as [xmin, ymin, xmax, ymax].
"""
mask = self.mask.decode()
indices = np.where(mask == 1)
xmin, ymin = np.min(indices, axis=1)
xmax, ymax = np.max(indices, axis=1)
return xmin, ymin, xmax, ymax
93 changes: 83 additions & 10 deletions t4_devkit/tier4.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def __make_reverse_index__(self, verbose: bool) -> None:
category: Category = self.get("category", instance.category_token)
record.category_name = category.name

for record in self.surface_ann:
category: Category = self.get("category", record.category_token)
record.category_name = category.name

registered_channels: list[str] = []
for record in self.sample_data:
cs_record: CalibratedSensor = self.get(
Expand Down Expand Up @@ -216,6 +220,11 @@ def __make_reverse_index__(self, verbose: bool) -> None:
sample_record: Sample = self.get("sample", sd_record.sample_token)
sample_record.ann_2ds.append(ann_record.token)

for ann_record in self.surface_ann:
sd_record: SampleData = self.get("sample_data", ann_record.sample_data_token)
sample_record: Sample = self.get("sample", sd_record.sample_token)
sample_record.surface_anns.append(ann_record.token)

log_to_map: dict[str, str] = {}
for map_record in self.map:
for log_token in map_record.log_tokens:
Expand Down Expand Up @@ -1127,18 +1136,48 @@ def _render_annotation_2ds(
if max_timestamp_us < sample.timestamp:
break

if instance_token is not None:
boxes = []
for ann_token in sample.ann_2ds:
ann: ObjectAnn = self.get("object_ann", ann_token)
if ann.instance_token == instance_token:
boxes.append(self.get_box2d(ann_token))
break
else:
boxes = list(map(self.get_box2d, sample.ann_2ds))
boxes: list[Box2D] = []

# For segmentation masks
# TODO: declare specific class for segmentation mask in `dataclass`
camera_masks: dict[str, dict[str, list]] = {}

# Object Annotation
for ann_token in sample.ann_2ds:
ann: ObjectAnn = self.get("object_ann", ann_token)
box = self.get_box2d(ann_token)
boxes.append(box)

sample_data: SampleData = self.get("sample_data", ann.sample_data_token)
camera_masks = _append_mask(
camera_masks,
camera=sample_data.channel,
ann=ann,
class_id=self._label2id[ann.category_name],
uuid=box.uuid,
)

# Render 2D box
viewer.render_box2ds(us2sec(sample.timestamp), boxes)

# TODO: add support of rendering object/surface mask and keypoints
# Surface Annotation
for ann_token in sample.surface_anns:
sample_data: SampleData = self.get("sample_data", ann.sample_data_token)
ann: SurfaceAnn = self.get("surface_ann", ann_token)
camera_masks = _append_mask(
camera_masks,
camera=sample_data.channel,
ann=ann,
class_id=self._label2id[ann.category_name],
)

# Render 2D segmentation image
for camera, data in camera_masks.items():
viewer.render_segmentation2d(
seconds=us2sec(sample.timestamp), camera=camera, **data
)

# TODO: add support of rendering keypoints
current_sample_token = sample.next

def _render_sensor_calibration(self, viewer: Tier4Viewer, sample_data_token: str) -> None:
Expand All @@ -1154,3 +1193,37 @@ def _render_sensor_calibration(self, viewer: Tier4Viewer, sample_data_token: str
)
sensor: Sensor = self.get("sensor", calibrated_sensor.sensor_token)
viewer.render_calibration(sensor, calibrated_sensor)


def _append_mask(
camera_masks: dict[str, dict[str, list]],
camera: str,
ann: ObjectAnn | SurfaceAnn,
class_id: int,
uuid: str | None = None,
) -> dict[str, dict[str, list]]:
"""Append segmentation mask data from `ObjectAnn/SurfaceAnn`.

TODO:
This function should be removed after declaring specific dataclass for 2D segmentation.

Args:
camera_masks (dict[str, dict[str, list]]): Key-value data mapping camera name and mask data.
camera (str): Name of camera channel.
ann (ObjectAnn | SurfaceAnn): Annotation object.
class_id (int): Class ID.
uuid (str | None, optional): Unique instance identifier.

Returns:
dict[str, dict[str, list]]: Updated `camera_masks`.
"""
if camera in camera_masks:
camera_masks[camera]["masks"].append(ann.mask.decode())
camera_masks[camera]["class_ids"].append(class_id)
camera_masks[camera]["uuids"].append(uuid)
else:
camera_masks[camera] = {}
camera_masks[camera]["masks"] = [ann.mask.decode()]
camera_masks[camera]["class_ids"] = [class_id]
camera_masks[camera]["uuids"] = [class_id]
return camera_masks
2 changes: 0 additions & 2 deletions t4_devkit/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"CamIntrinsicType",
"CamDistortionType",
"RoiType",
"MaskType",
"KeypointType",
)

Expand Down Expand Up @@ -54,5 +53,4 @@

# 2D
RoiType = NewType("RoiType", tuple[int, int, int, int]) # (xmin, ymin, xmax, ymax)
MaskType = NewType("MaskType", list[int])
KeypointType = NewType("KeypointType", NDArrayF64)
Loading
Loading