diff --git a/t4_devkit/dataclass/box.py b/t4_devkit/dataclass/box.py index f7b6729..f585765 100644 --- a/t4_devkit/dataclass/box.py +++ b/t4_devkit/dataclass/box.py @@ -94,7 +94,7 @@ class Box3D(BaseBox): >>> box3d = Box3D( ... unix_time=100, ... frame_id="base_link", - ... semantic_label=SemanticLabel(LabelID.CAR), + ... semantic_label=SemanticLabel("car"), ... position=(1.0, 1.0, 1.0), ... rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), ... shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), @@ -213,7 +213,7 @@ class Box2D(BaseBox): >>> box2d = Box2D( ... unix_time=100, ... frame_id="camera", - ... semantic_label=SemanticLabel(LabelID.CAR), + ... semantic_label=SemanticLabel("car"), ... roi=(100, 100, 50, 50), ... confidence=1.0, ... uuid="car2d_0", diff --git a/t4_devkit/dataclass/label.py b/t4_devkit/dataclass/label.py index ab5c2e2..d94179a 100644 --- a/t4_devkit/dataclass/label.py +++ b/t4_devkit/dataclass/label.py @@ -1,61 +1,8 @@ from __future__ import annotations -import warnings from dataclasses import dataclass, field -from enum import Enum, auto, unique -from typing_extensions import Self - -__all__ = ["LabelID", "SemanticLabel", "convert_label"] - - -@unique -class LabelID(Enum): - """Enum of label elements.""" - - # catch all labels - UNKNOWN = 0 - - # object labels - CAR = auto() - TRUCK = auto() - BUS = auto() - BICYCLE = auto() - MOTORBIKE = auto() - PEDESTRIAN = auto() - ANIMAL = auto() - - # traffic-light labels - TRAFFIC_LIGHT = auto() - GREEN = auto() - GREEN_STRAIGHT = auto() - GREEN_LEFT = auto() - GREEN_RIGHT = auto() - YELLOW = auto() - YELLOW_STRAIGHT = auto() - YELLOW_LEFT = auto() - YELLOW_RIGHT = auto() - YELLOW_STRAIGHT_LEFT = auto() - YELLOW_STRAIGHT_RIGHT = auto() - YELLOW_STRAIGHT_LEFT_RIGHT = auto() - RED = auto() - RED_STRAIGHT = auto() - RED_LEFT = auto() - RED_RIGHT = auto() - RED_STRAIGHT_LEFT = auto() - RED_STRAIGHT_RIGHT = auto() - RED_STRAIGHT_LEFT_RIGHT = auto() - RED_LEFT_DIAGONAL = auto() - RED_RIGHT_DIAGONAL = auto() - - @classmethod - def from_name(cls, name: str) -> Self: - name = name.upper() - assert name in cls.__members__, f"Unexpected label name: {name}" - return cls.__members__[name] - - def __eq__(self, other: str | LabelID) -> bool: - return self.name == other.upper() if isinstance(other, str) else self.name == other.name +__all__ = ["SemanticLabel"] @dataclass(frozen=True, eq=False) @@ -63,146 +10,12 @@ class SemanticLabel: """A dataclass to represent semantic labels. Attributes: - label (LabelID): Label ID. - original (str | None, optional): Original name of the label. + name (str): Label name. attributes (list[str], optional): List of attribute names. """ - label: LabelID - original: str | None = field(default=None) + name: str attributes: list[str] = field(default_factory=list) def __eq__(self, other: str | SemanticLabel) -> bool: - return self.label == other if isinstance(other, str) else self.label == other.label - - -# ===================== -# Label conversion -# ===================== - -# Name mapping (key: value) = (original: Label enum) -DEFAULT_NAME_MAPPING: dict[str, str] = { - # === ObjectLabel === - # CAR - "car": "CAR", - "vehicle.car": "CAR", - "vehicle.construction": "CAR", - "vehicle.emergency (ambulance & police)": "CAR", - "vehicle.police": "CAR", - # TRUCK - "truck": "TRUCK", - "vehicle.truck": "TRUCK", - "trailer": "TRUCK", - "vehicle.trailer": "TRUCK", - # BUS - "bus": "BUS", - "vehicle.bus": "BUS", - "vehicle.bus (bendy & rigid)": "BUS", - # BICYCLE - "bicycle": "BICYCLE", - "vehicle.bicycle": "BICYCLE", - # MOTORBIKE - "motorbike": "MOTORBIKE", - "vehicle.motorbike": "MOTORBIKE", - "motorcycle": "MOTORBIKE", - "vehicle.motorcycle": "MOTORBIKE", - # PEDESTRIAN - "pedestrian": "PEDESTRIAN", - "pedestrian.child": "PEDESTRIAN", - "pedestrian.personal_mobility": "PEDESTRIAN", - "pedestrian.police_officer": "PEDESTRIAN", - "pedestrian.stroller": "PEDESTRIAN", - "pedestrian.wheelchair": "PEDESTRIAN", - "construction_worker": "PEDESTRIAN", - # ANIMAL - "animal": "ANIMAL", - # UNKNOWN - "movable_object.barrier": "UNKNOWN", - "movable_object.debris": "UNKNOWN", - "movable_object.pushable_pullable": "UNKNOWN", - "movable_object.trafficcone": "UNKNOWN", - "movable_object.traffic_cone": "UNKNOWN", - "static_object.bicycle_lack": "UNKNOWN", - "static_object.bollard": "UNKNOWN", - "forklift": "UNKNOWN", - # === TrafficLightLabel === - # GREEN - "green": "GREEN", - "green_straight": "GREEN_STRAIGHT", - "green_left": "GREEN_LEFT", - "green_right": "GREEN_RIGHT", - # YELLOW - "yellow": "YELLOW", - "yellow_straight": "YELLOW_STRAIGHT", - "yellow_left": "YELLOW_LEFT", - "yellow_right": "YELLOW_RIGHT", - "yellow_straight_left": "YELLOW_STRAIGHT_LEFT", - "yellow_straight_right": "YELLOW_STRAIGHT_RIGHT", - "yellow_straight_left_right": "YELLOW_STRAIGHT_LEFT_RIGHT", - # RED - "red": "RED", - "red_straight": "RED_STRAIGHT", - "red_left": "RED_LEFT", - "red_right": "RED_RIGHT", - "red_straight_left": "RED_STRAIGHT_LEFT", - "red_straight_right": "RED_STRAIGHT_RIGHT", - "red_straight_left_right": "RED_STRAIGHT_LEFT_RIGHT", - "red_straight_left_diagonal": "RED_LEFT_DIAGONAL", - "red_straight_leftdiagonal": "RED_LEFT_DIAGONAL", - "red_straight_right_diagonal": "RED_RIGHT_DIAGONAL", - "red_straight_rightdiagonal": "RED_RIGHT_DIAGONAL", - # CROSSWALK - "crosswalk_red": "RED", - "crosswalk_green": "GREEN", - "crosswalk_unknown": "UNKNOWN", - "unknown": "UNKNOWN", -} - - -def convert_label( - original: str, - attributes: list[str] | None = None, - *, - name_mapping: dict[str, str] | None = None, - update_default_mapping: bool = False, -) -> SemanticLabel: - """Covert string original label name to `SemanticLabel` object. - - Args: - original (str): Original label name. For example, `vehicle.car`. - attributes (list[str] | None, optional): List of label attributes. - name_mapping (dict[str, str] | None, optional): Name mapping for original and label. - If `None`, `DEFAULT_NAME_MAPPING` will be used. - update_default_mapping (bool, optional): Whether to update `DEFAULT_NAME_MAPPING` by - `name_mapping`. If `False` and `name_mapping` is specified, - the specified `name_mapping` is used instead of `DEFAULT_NAME_MAPPING` completely. - Note that, this parameter works only if `name_mapping` is specified. - - Returns: - Converted `SemanticLabel` object. - """ - global DEFAULT_NAME_MAPPING - - # set name mapping - if name_mapping is None: - name_mapping = DEFAULT_NAME_MAPPING - elif update_default_mapping: - DEFAULT_NAME_MAPPING.update(name_mapping) - - # convert original to name for Label object - if original in name_mapping: - name = name_mapping[original] - else: - warnings.warn( - f"{original} is not included in mapping, use UNKNOWN.", - UserWarning, - ) - name = "UNKNOWN" - - label = LabelID.from_name(name) - - return ( - SemanticLabel(label, original) - if attributes is None - else SemanticLabel(label, original, attributes) - ) + return self.name == other if isinstance(other, str) else self.name == other.name diff --git a/t4_devkit/tier4.py b/t4_devkit/tier4.py index 59b1346..1bdfffa 100644 --- a/t4_devkit/tier4.py +++ b/t4_devkit/tier4.py @@ -11,15 +11,7 @@ from t4_devkit.common.geometry import is_box_in_image, view_points from t4_devkit.common.timestamp import sec2us, us2sec -from t4_devkit.dataclass import ( - Box2D, - Box3D, - LidarPointCloud, - RadarPointCloud, - Shape, - ShapeType, - convert_label, -) +from t4_devkit.dataclass import Box2D, Box3D, LidarPointCloud, RadarPointCloud, Shape, ShapeType from t4_devkit.schema import SchemaName, SensorModality, VisibilityLevel, build_schema from t4_devkit.viewer import Tier4Viewer, distance_color, format_entity @@ -371,17 +363,12 @@ def get_semantic_label( self, category_token: str, attribute_tokens: list[str] | None = None, - name_mapping: dict[str, str] | None = None, - *, - update_default_mapping: bool = False, ) -> SemanticLabel: """Return a SemanticLabel instance from specified `category_token` and `attribute_tokens`. Args: category_token (str): Token of `Category` table. attribute_tokens (list[str] | None, optional): List of attribute tokens. - name_mapping (dict[str, str] | None, optional): Category name mapping. - update_default_mapping (bool, optional): Whether to update default category name mapping. Returns: Instantiated SemanticLabel. @@ -393,12 +380,7 @@ def get_semantic_label( else [] ) - return convert_label( - original=category.name, - attributes=attributes, - name_mapping=name_mapping, - update_default_mapping=update_default_mapping, - ) + return SemanticLabel(category.name, attributes) def get_box3d(self, sample_annotation_token: str) -> Box3D: """Return a Box3D class from a `sample_annotation` record. diff --git a/t4_devkit/viewer/rendering_data/box.py b/t4_devkit/viewer/rendering_data/box.py index 3e4de22..9c9f804 100644 --- a/t4_devkit/viewer/rendering_data/box.py +++ b/t4_devkit/viewer/rendering_data/box.py @@ -23,6 +23,8 @@ def __init__(self) -> None: self._uuids: list[int] = [] self._velocities: list[VelocityType] = [] + self._label2id: dict[str, int] = {} + def append(self, box: Box3D) -> None: """Append a 3D box data. @@ -37,7 +39,10 @@ def append(self, box: Box3D) -> None: width, length, height = box.size self._sizes.append((length, width, height)) - self._class_ids.append(box.semantic_label.label.value) + if box.semantic_label.name not in self._label2id: + self._label2id[box.semantic_label.name] = len(self._label2id) + + self._class_ids.append(self._label2id[box.semantic_label.name]) if box.uuid is not None: self._uuids.append(box.uuid[:6]) @@ -81,6 +86,8 @@ def __init__(self) -> None: self._uuids: list[str] = [] self._class_ids: list[int] = [] + self._label2id: dict[str, int] = {} + def append(self, box: Box2D) -> None: """Append a 2D box data. @@ -89,7 +96,10 @@ def append(self, box: Box2D) -> None: """ self._rois.append(box.roi.roi) - self._class_ids.append(box.semantic_label.label.value) + if box.semantic_label.name not in self._label2id: + self._label2id[box.semantic_label.name] = len(self._label2id) + + self._class_ids.append(self._label2id[box.semantic_label.name]) if box.uuid is not None: self._uuids.append(box.uuid) diff --git a/tests/conftest.py b/tests/conftest.py index 985c597..6ed3cff 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,6 @@ Box2D, Box3D, HomogeneousMatrix, - LabelID, SemanticLabel, Shape, ShapeType, @@ -23,7 +22,7 @@ def dummy_box3d() -> Box3D: return Box3D( unix_time=100, frame_id="base_link", - semantic_label=SemanticLabel(LabelID.CAR), + semantic_label=SemanticLabel("car"), position=(1.0, 1.0, 1.0), rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), @@ -44,7 +43,7 @@ def dummy_box3ds() -> list[Box3D]: Box3D( unix_time=100, frame_id="base_link", - semantic_label=SemanticLabel(LabelID.CAR), + semantic_label=SemanticLabel("car"), position=(1.0, 1.0, 1.0), rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), @@ -55,7 +54,7 @@ def dummy_box3ds() -> list[Box3D]: Box3D( unix_time=100, frame_id="base_link", - semantic_label=SemanticLabel(LabelID.BICYCLE), + semantic_label=SemanticLabel("bicycle"), position=(-1.0, -1.0, 1.0), rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), @@ -66,7 +65,7 @@ def dummy_box3ds() -> list[Box3D]: Box3D( unix_time=100, frame_id="base_link", - semantic_label=SemanticLabel(LabelID.PEDESTRIAN), + semantic_label=SemanticLabel("pedestrian"), position=(-1.0, 1.0, 1.0), rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), @@ -87,7 +86,7 @@ def dummy_box2d() -> Box2D: return Box2D( unix_time=100, frame_id="camera", - semantic_label=SemanticLabel(LabelID.CAR), + semantic_label=SemanticLabel("car"), roi=(100, 100, 50, 50), confidence=1.0, uuid="car2d_0", @@ -105,7 +104,7 @@ def dummy_box2ds() -> list[Box2D]: Box2D( unix_time=100, frame_id="camera", - semantic_label=SemanticLabel(LabelID.CAR), + semantic_label=SemanticLabel("car"), roi=(100, 100, 50, 50), confidence=1.0, uuid="car2d_1", @@ -113,7 +112,7 @@ def dummy_box2ds() -> list[Box2D]: Box2D( unix_time=100, frame_id="camera", - semantic_label=SemanticLabel(LabelID.BICYCLE), + semantic_label=SemanticLabel("bicycle"), roi=(50, 50, 10, 10), confidence=1.0, uuid="bicycle2d_1", @@ -121,7 +120,7 @@ def dummy_box2ds() -> list[Box2D]: Box2D( unix_time=100, frame_id="camera", - semantic_label=SemanticLabel(LabelID.PEDESTRIAN), + semantic_label=SemanticLabel("pedestrian"), roi=(150, 150, 20, 20), confidence=1.0, uuid="pedestrian2d_1", diff --git a/tests/dataclass/test_label.py b/tests/dataclass/test_label.py deleted file mode 100644 index 3bbdd47..0000000 --- a/tests/dataclass/test_label.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest - -from t4_devkit.dataclass.label import LabelID, convert_label - - -@pytest.mark.parametrize( - ("labels", "expect"), - [ - # === object === - # car - ( - ( - "car", - "vehicle.car", - "vehicle.construction", - "vehicle.emergency (ambulance & police)", - "vehicle.police", - ), - LabelID.CAR, - ), - # truck - (("truck", "vehicle.truck", "trailer", "vehicle.trailer"), LabelID.TRUCK), - # bus - (("bus", "vehicle.bus", "vehicle.bus (bendy & rigid)"), LabelID.BUS), - # bicycle - (("bicycle", "vehicle.bicycle"), LabelID.BICYCLE), - # motorbike - ( - ("motorbike", "vehicle.motorbike", "motorcycle", "vehicle.motorcycle"), - LabelID.MOTORBIKE, - ), - # pedestrian - ( - ( - "pedestrian", - "pedestrian.child", - "pedestrian.personal_mobility", - "pedestrian.police_officer", - "pedestrian.stroller", - "pedestrian.wheelchair", - "construction_worker", - ), - LabelID.PEDESTRIAN, - ), - # animal - ("animal", LabelID.ANIMAL), - # unknown - ( - ( - "movable_object.barrier", - "movable_object.debris", - "movable_object.pushable_pullable", - "movable_object.trafficcone", - "movable_object.traffic_cone", - "static_object.bicycle_lack", - "static_object.bollard", - "forklift", - ), - LabelID.UNKNOWN, - ), - # === traffic light === - # GREEN - (("green", "crosswalk_green"), LabelID.GREEN), - ("green_straight", LabelID.GREEN_STRAIGHT), - ("green_left", LabelID.GREEN_LEFT), - ("green_right", LabelID.GREEN_RIGHT), - # YELLOW - ("yellow", LabelID.YELLOW), - ("yellow_straight", LabelID.YELLOW_STRAIGHT), - ("yellow_left", LabelID.YELLOW_LEFT), - ("yellow_right", LabelID.YELLOW_RIGHT), - ("yellow_straight_left", LabelID.YELLOW_STRAIGHT_LEFT), - ("yellow_straight_right", LabelID.YELLOW_STRAIGHT_RIGHT), - ("yellow_straight_left_right", LabelID.YELLOW_STRAIGHT_LEFT_RIGHT), - # RED - (("red", "crosswalk_red"), LabelID.RED), - ("red_straight", LabelID.RED_STRAIGHT), - ("red_left", LabelID.RED_LEFT), - ("red_right", LabelID.RED_RIGHT), - ("red_straight_left", LabelID.RED_STRAIGHT_LEFT), - ("red_straight_right", LabelID.RED_STRAIGHT_RIGHT), - ("red_straight_left_right", LabelID.RED_STRAIGHT_LEFT_RIGHT), - ( - ("red_straight_left_diagonal", "red_straight_leftdiagonal"), - LabelID.RED_LEFT_DIAGONAL, - ), - ( - ("red_straight_right_diagonal", "red_straight_rightdiagonal"), - LabelID.RED_RIGHT_DIAGONAL, - ), - # unknown traffic light - (("unknown", "crosswalk_unknown"), LabelID.UNKNOWN), - ], -) -def test_convert_label(labels: str | tuple[str, ...], expect: LabelID) -> None: - if isinstance(labels, str): - labels = [labels] - - for original in labels: - ret = convert_label(original) - assert ret.label == expect - assert ret.original == original