Skip to content

Commit

Permalink
refactor: allow any label name
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Nov 14, 2024
1 parent 0ad78fa commit f5257aa
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 326 deletions.
4 changes: 2 additions & 2 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class Box3D(BaseBox):
>>> box3d = Box3D(
... unix_time=100,
... frame_id="base_link",
... semantic_label=SemanticLabel(LabelID.CAR),
... semantic_label=SemanticLabel("car"),
... position=(1.0, 1.0, 1.0),
... rotation=Quaternion([0.0, 0.0, 0.0, 1.0]),
... shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)),
Expand Down Expand Up @@ -213,7 +213,7 @@ class Box2D(BaseBox):
>>> box2d = Box2D(
... unix_time=100,
... frame_id="camera",
... semantic_label=SemanticLabel(LabelID.CAR),
... semantic_label=SemanticLabel("car"),
... roi=(100, 100, 50, 50),
... confidence=1.0,
... uuid="car2d_0",
Expand Down
195 changes: 4 additions & 191 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
@@ -1,208 +1,21 @@
from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from enum import Enum, auto, unique

from typing_extensions import Self

__all__ = ["LabelID", "SemanticLabel", "convert_label"]


@unique
class LabelID(Enum):
"""Enum of label elements."""

# catch all labels
UNKNOWN = 0

# object labels
CAR = auto()
TRUCK = auto()
BUS = auto()
BICYCLE = auto()
MOTORBIKE = auto()
PEDESTRIAN = auto()
ANIMAL = auto()

# traffic-light labels
TRAFFIC_LIGHT = auto()
GREEN = auto()
GREEN_STRAIGHT = auto()
GREEN_LEFT = auto()
GREEN_RIGHT = auto()
YELLOW = auto()
YELLOW_STRAIGHT = auto()
YELLOW_LEFT = auto()
YELLOW_RIGHT = auto()
YELLOW_STRAIGHT_LEFT = auto()
YELLOW_STRAIGHT_RIGHT = auto()
YELLOW_STRAIGHT_LEFT_RIGHT = auto()
RED = auto()
RED_STRAIGHT = auto()
RED_LEFT = auto()
RED_RIGHT = auto()
RED_STRAIGHT_LEFT = auto()
RED_STRAIGHT_RIGHT = auto()
RED_STRAIGHT_LEFT_RIGHT = auto()
RED_LEFT_DIAGONAL = auto()
RED_RIGHT_DIAGONAL = auto()

@classmethod
def from_name(cls, name: str) -> Self:
name = name.upper()
assert name in cls.__members__, f"Unexpected label name: {name}"
return cls.__members__[name]

def __eq__(self, other: str | LabelID) -> bool:
return self.name == other.upper() if isinstance(other, str) else self.name == other.name
__all__ = ["SemanticLabel"]


@dataclass(frozen=True, eq=False)
class SemanticLabel:
"""A dataclass to represent semantic labels.
Attributes:
label (LabelID): Label ID.
original (str | None, optional): Original name of the label.
name (str): Label name.
attributes (list[str], optional): List of attribute names.
"""

label: LabelID
original: str | None = field(default=None)
name: str
attributes: list[str] = field(default_factory=list)

def __eq__(self, other: str | SemanticLabel) -> bool:
return self.label == other if isinstance(other, str) else self.label == other.label


# =====================
# Label conversion
# =====================

# Name mapping (key: value) = (original: Label enum)
DEFAULT_NAME_MAPPING: dict[str, str] = {
# === ObjectLabel ===
# CAR
"car": "CAR",
"vehicle.car": "CAR",
"vehicle.construction": "CAR",
"vehicle.emergency (ambulance & police)": "CAR",
"vehicle.police": "CAR",
# TRUCK
"truck": "TRUCK",
"vehicle.truck": "TRUCK",
"trailer": "TRUCK",
"vehicle.trailer": "TRUCK",
# BUS
"bus": "BUS",
"vehicle.bus": "BUS",
"vehicle.bus (bendy & rigid)": "BUS",
# BICYCLE
"bicycle": "BICYCLE",
"vehicle.bicycle": "BICYCLE",
# MOTORBIKE
"motorbike": "MOTORBIKE",
"vehicle.motorbike": "MOTORBIKE",
"motorcycle": "MOTORBIKE",
"vehicle.motorcycle": "MOTORBIKE",
# PEDESTRIAN
"pedestrian": "PEDESTRIAN",
"pedestrian.child": "PEDESTRIAN",
"pedestrian.personal_mobility": "PEDESTRIAN",
"pedestrian.police_officer": "PEDESTRIAN",
"pedestrian.stroller": "PEDESTRIAN",
"pedestrian.wheelchair": "PEDESTRIAN",
"construction_worker": "PEDESTRIAN",
# ANIMAL
"animal": "ANIMAL",
# UNKNOWN
"movable_object.barrier": "UNKNOWN",
"movable_object.debris": "UNKNOWN",
"movable_object.pushable_pullable": "UNKNOWN",
"movable_object.trafficcone": "UNKNOWN",
"movable_object.traffic_cone": "UNKNOWN",
"static_object.bicycle_lack": "UNKNOWN",
"static_object.bollard": "UNKNOWN",
"forklift": "UNKNOWN",
# === TrafficLightLabel ===
# GREEN
"green": "GREEN",
"green_straight": "GREEN_STRAIGHT",
"green_left": "GREEN_LEFT",
"green_right": "GREEN_RIGHT",
# YELLOW
"yellow": "YELLOW",
"yellow_straight": "YELLOW_STRAIGHT",
"yellow_left": "YELLOW_LEFT",
"yellow_right": "YELLOW_RIGHT",
"yellow_straight_left": "YELLOW_STRAIGHT_LEFT",
"yellow_straight_right": "YELLOW_STRAIGHT_RIGHT",
"yellow_straight_left_right": "YELLOW_STRAIGHT_LEFT_RIGHT",
# RED
"red": "RED",
"red_straight": "RED_STRAIGHT",
"red_left": "RED_LEFT",
"red_right": "RED_RIGHT",
"red_straight_left": "RED_STRAIGHT_LEFT",
"red_straight_right": "RED_STRAIGHT_RIGHT",
"red_straight_left_right": "RED_STRAIGHT_LEFT_RIGHT",
"red_straight_left_diagonal": "RED_LEFT_DIAGONAL",
"red_straight_leftdiagonal": "RED_LEFT_DIAGONAL",
"red_straight_right_diagonal": "RED_RIGHT_DIAGONAL",
"red_straight_rightdiagonal": "RED_RIGHT_DIAGONAL",
# CROSSWALK
"crosswalk_red": "RED",
"crosswalk_green": "GREEN",
"crosswalk_unknown": "UNKNOWN",
"unknown": "UNKNOWN",
}


def convert_label(
original: str,
attributes: list[str] | None = None,
*,
name_mapping: dict[str, str] | None = None,
update_default_mapping: bool = False,
) -> SemanticLabel:
"""Covert string original label name to `SemanticLabel` object.
Args:
original (str): Original label name. For example, `vehicle.car`.
attributes (list[str] | None, optional): List of label attributes.
name_mapping (dict[str, str] | None, optional): Name mapping for original and label.
If `None`, `DEFAULT_NAME_MAPPING` will be used.
update_default_mapping (bool, optional): Whether to update `DEFAULT_NAME_MAPPING` by
`name_mapping`. If `False` and `name_mapping` is specified,
the specified `name_mapping` is used instead of `DEFAULT_NAME_MAPPING` completely.
Note that, this parameter works only if `name_mapping` is specified.
Returns:
Converted `SemanticLabel` object.
"""
global DEFAULT_NAME_MAPPING

# set name mapping
if name_mapping is None:
name_mapping = DEFAULT_NAME_MAPPING
elif update_default_mapping:
DEFAULT_NAME_MAPPING.update(name_mapping)

# convert original to name for Label object
if original in name_mapping:
name = name_mapping[original]
else:
warnings.warn(
f"{original} is not included in mapping, use UNKNOWN.",
UserWarning,
)
name = "UNKNOWN"

label = LabelID.from_name(name)

return (
SemanticLabel(label, original)
if attributes is None
else SemanticLabel(label, original, attributes)
)
return self.name == other if isinstance(other, str) else self.name == other.name
22 changes: 2 additions & 20 deletions t4_devkit/tier4.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,7 @@

from t4_devkit.common.geometry import is_box_in_image, view_points
from t4_devkit.common.timestamp import sec2us, us2sec
from t4_devkit.dataclass import (
Box2D,
Box3D,
LidarPointCloud,
RadarPointCloud,
Shape,
ShapeType,
convert_label,
)
from t4_devkit.dataclass import Box2D, Box3D, LidarPointCloud, RadarPointCloud, Shape, ShapeType
from t4_devkit.schema import SchemaName, SensorModality, VisibilityLevel, build_schema
from t4_devkit.viewer import Tier4Viewer, distance_color, format_entity

Expand Down Expand Up @@ -371,17 +363,12 @@ def get_semantic_label(
self,
category_token: str,
attribute_tokens: list[str] | None = None,
name_mapping: dict[str, str] | None = None,
*,
update_default_mapping: bool = False,
) -> SemanticLabel:
"""Return a SemanticLabel instance from specified `category_token` and `attribute_tokens`.
Args:
category_token (str): Token of `Category` table.
attribute_tokens (list[str] | None, optional): List of attribute tokens.
name_mapping (dict[str, str] | None, optional): Category name mapping.
update_default_mapping (bool, optional): Whether to update default category name mapping.
Returns:
Instantiated SemanticLabel.
Expand All @@ -393,12 +380,7 @@ def get_semantic_label(
else []
)

return convert_label(
original=category.name,
attributes=attributes,
name_mapping=name_mapping,
update_default_mapping=update_default_mapping,
)
return SemanticLabel(category.name, attributes)

def get_box3d(self, sample_annotation_token: str) -> Box3D:
"""Return a Box3D class from a `sample_annotation` record.
Expand Down
14 changes: 12 additions & 2 deletions t4_devkit/viewer/rendering_data/box.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ def __init__(self) -> None:
self._uuids: list[int] = []
self._velocities: list[VelocityType] = []

self._label2id: dict[str, int] = {}

def append(self, box: Box3D) -> None:
"""Append a 3D box data.
Expand All @@ -37,7 +39,10 @@ def append(self, box: Box3D) -> None:
width, length, height = box.size
self._sizes.append((length, width, height))

self._class_ids.append(box.semantic_label.label.value)
if box.semantic_label.name not in self._label2id:
self._label2id[box.semantic_label.name] = len(self._label2id)

self._class_ids.append(self._label2id[box.semantic_label.name])

if box.uuid is not None:
self._uuids.append(box.uuid[:6])
Expand Down Expand Up @@ -81,6 +86,8 @@ def __init__(self) -> None:
self._uuids: list[str] = []
self._class_ids: list[int] = []

self._label2id: dict[str, int] = {}

def append(self, box: Box2D) -> None:
"""Append a 2D box data.
Expand All @@ -89,7 +96,10 @@ def append(self, box: Box2D) -> None:
"""
self._rois.append(box.roi.roi)

self._class_ids.append(box.semantic_label.label.value)
if box.semantic_label.name not in self._label2id:
self._label2id[box.semantic_label.name] = len(self._label2id)

self._class_ids.append(self._label2id[box.semantic_label.name])

if box.uuid is not None:
self._uuids.append(box.uuid)
Expand Down
Loading

0 comments on commit f5257aa

Please sign in to comment.