From 0d7736e6da6d248768b5b8eae55def390bfa28da Mon Sep 17 00:00:00 2001 From: Kotaro Uetake <60615504+ktro2828@users.noreply.github.com> Date: Fri, 18 Oct 2024 00:44:26 +0900 Subject: [PATCH] feat: add data-classes to represent objects (#10) * feat: add dataclasses Signed-off-by: ktro2828 * docs: update API documentation Signed-off-by: ktro2828 * feat: apply `dataclass.Box*` Signed-off-by: ktro2828 * feat: update `Label` enum and its test Signed-off-by: ktro2828 * feat: rename `Label` to `LabelID` Signed-off-by: ktro2828 * test: add unit testing for boxes Signed-off-by: ktro2828 * feat: rename `Label` to `LabelID` Signed-off-by: ktro2828 * test: add unit testing for boxes Signed-off-by: ktro2828 * TODO: add dataclasses of pointcloud Signed-off-by: ktro2828 * chore: move `transform.py` into `common` Signed-off-by: ktro2828 * chore: move `transform.py` into `common` Signed-off-by: ktro2828 * docs: update documents Signed-off-by: ktro2828 * feat: update dataclasses Signed-off-by: ktro2828 * style(pre-commit): autofix --------- Signed-off-by: ktro2828 Co-authored-by: ktro2828 --- .gitignore | 1 + docs/apis/common.md | 6 +- docs/apis/dataclass.md | 23 ++ mkdocs.yaml | 1 + pyproject.toml | 7 +- t4_devkit/common/box.py | 76 ------ t4_devkit/common/color.py | 4 +- t4_devkit/common/geometry.py | 12 +- t4_devkit/common/transform.py | 267 +++++++++++++++++++ t4_devkit/dataclass/__init__.py | 6 + t4_devkit/dataclass/box.py | 250 +++++++++++++++++ t4_devkit/dataclass/label.py | 208 +++++++++++++++ t4_devkit/dataclass/pointcloud.py | 198 ++++++++++++++ t4_devkit/dataclass/roi.py | 49 ++++ t4_devkit/dataclass/shape.py | 89 +++++++ t4_devkit/dataclass/trajectory.py | 88 ++++++ t4_devkit/schema/tables/ego_pose.py | 4 +- t4_devkit/schema/tables/registry.py | 8 +- t4_devkit/schema/tables/sample_annotation.py | 4 +- t4_devkit/tier4.py | 184 ++++++++----- t4_devkit/typing.py | 7 + tests/common/test_geometry.py | 3 +- tests/common/test_transform.py | 89 +++++++ tests/conftest.py | 120 
+++++++++ tests/dataclass/test_box.py | 36 +++ tests/dataclass/test_label.py | 102 +++++++ tests/dataclass/test_roi.py | 13 + tests/dataclass/test_shape.py | 23 ++ tests/dataclass/test_trajectory.py | 56 ++++ 29 files changed, 1754 insertions(+), 180 deletions(-) create mode 100644 docs/apis/dataclass.md delete mode 100644 t4_devkit/common/box.py create mode 100644 t4_devkit/common/transform.py create mode 100644 t4_devkit/dataclass/__init__.py create mode 100644 t4_devkit/dataclass/box.py create mode 100644 t4_devkit/dataclass/label.py create mode 100644 t4_devkit/dataclass/pointcloud.py create mode 100644 t4_devkit/dataclass/roi.py create mode 100644 t4_devkit/dataclass/shape.py create mode 100644 t4_devkit/dataclass/trajectory.py create mode 100644 tests/common/test_transform.py create mode 100644 tests/conftest.py create mode 100644 tests/dataclass/test_box.py create mode 100644 tests/dataclass/test_label.py create mode 100644 tests/dataclass/test_roi.py create mode 100644 tests/dataclass/test_shape.py create mode 100644 tests/dataclass/test_trajectory.py diff --git a/.gitignore b/.gitignore index 059b8bc..b06fe3c 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
#.idea/ +# Editor configurations .vscode diff --git a/docs/apis/common.md b/docs/apis/common.md index 4ebb491..df4310e 100644 --- a/docs/apis/common.md +++ b/docs/apis/common.md @@ -1,10 +1,6 @@ # `common` -::: t4_devkit.common.box - options: - show_bases: false - ::: t4_devkit.common.color ::: t4_devkit.common.geometry @@ -12,4 +8,6 @@ ::: t4_devkit.common.io ::: t4_devkit.common.timestamp + +::: t4_devkit.common.transform diff --git a/docs/apis/dataclass.md b/docs/apis/dataclass.md new file mode 100644 index 0000000..b01e8c9 --- /dev/null +++ b/docs/apis/dataclass.md @@ -0,0 +1,23 @@ +# `dataclass` + + +::: t4_devkit.dataclass.box + options: + filters: ["!BaseBox"] + show_bases: false + +::: t4_devkit.dataclass.label + options: + show_bases: false + +::: t4_devkit.dataclass.pointcloud + options: + show_bases: false + +::: t4_devkit.dataclass.roi + +::: t4_devkit.dataclass.shape + +::: t4_devkit.dataclass.trajectory + + diff --git a/mkdocs.yaml b/mkdocs.yaml index e4d54ef..a16eb07 100644 --- a/mkdocs.yaml +++ b/mkdocs.yaml @@ -9,6 +9,7 @@ nav: - API: - TIER IV: apis/tier4.md - Schema: apis/schema.md + - DataClass: apis/dataclass.md - Common: apis/common.md theme: diff --git a/pyproject.toml b/pyproject.toml index 79afcda..6d1e375 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,9 @@ packages = [{ include = "t4_devkit" }] [tool.poetry.dependencies] python = ">=3.10,<3.13" rerun-sdk = "0.17.0" -nuscenes-devkit = "^1.1.11" +pyquaternion = "^0.9.9" +matplotlib = "^3.9.2" +shapely = "<2.0.0" [tool.poetry.group.dev.dependencies] pytest = "^8.2.2" @@ -24,3 +26,6 @@ ruff = "^0.6.8" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 100 diff --git a/t4_devkit/common/box.py b/t4_devkit/common/box.py deleted file mode 100644 index 27d05a9..0000000 --- a/t4_devkit/common/box.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, List, Tuple - -import 
numpy as np -from nuscenes.utils.data_classes import Box -from pyquaternion import Quaternion - -if TYPE_CHECKING: - from t4_devkit.typing import RoiType - -__all__ = ("Box3D", "Box2D") - - -class Box3D(Box): - """An wrapper of NuScenes Box.""" - - def __init__( - self, - center: List[float], - size: List[float], - orientation: Quaternion, - label: int = np.nan, - score: float = np.nan, - velocity: Tuple = (np.nan, np.nan, np.nan), - name: str = None, - token: str = None, - ) -> None: - """Construct instance. - - Args: - center (List[float]): Center of box given as (x, y, z). - size (List[float]): Size of box given as (width, length, height). - orientation (Quaternion): Box orientation. - label (int, optional): Integer label. - score (float, optional): Classification score. - velocity (Tuple, optional): Box velocity given as (vx, vy, vz). - name (str, optional): Box category name. - token (str, optional): Unique string identifier. - """ - super().__init__(center, size, orientation, label, score, velocity, name, token) - - -class Box2D: - """A class to represent 2D box.""" - - def __init__( - self, - roi: RoiType, - label: int = -1, - score: float = np.nan, - name: str | None = None, - token: str | None = None, - ) -> None: - """Construct instance. - - Args: - roi (RoiType): Roi elements, which is the order of (xmin, ymin, xmax, ymax). - label (int, optional): Box label. - score (float, optional): Box score. - name (str | None, optional): Category name. - token (str | None, optional): Unique identifier token corresponding to `token` of `object_ann`. 
- """ - self.xmin, self.ymin, self.xmax, self.ymax = roi - self.label = int(label) - self.score = float(score) if not np.isnan(score) else score - self.name = name - self.token = token - - @property - def width(self) -> int: - return self.xmax - self.xmin - - @property - def height(self) -> int: - return self.ymax - self.ymin diff --git a/t4_devkit/common/color.py b/t4_devkit/common/color.py index 5f76a9a..d00b289 100644 --- a/t4_devkit/common/color.py +++ b/t4_devkit/common/color.py @@ -28,8 +28,6 @@ def distance_color( Color map in the shape of (N,). If input type is any number, returns a color as `tuple[float, float, float]`. Otherwise, returns colors as `NDArrayF64`. """ - color_map = ( - matplotlib.colormaps["turbo_r"] if cmap is None else matplotlib.colormaps[cmap] - ) + color_map = matplotlib.colormaps["turbo_r"] if cmap is None else matplotlib.colormaps[cmap] norm = matplotlib.colors.Normalize(v_min, v_max) return color_map(norm(distances)) diff --git a/t4_devkit/common/geometry.py b/t4_devkit/common/geometry.py index b9e2b9a..e0c57a5 100644 --- a/t4_devkit/common/geometry.py +++ b/t4_devkit/common/geometry.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING import numpy as np + from t4_devkit.schema import VisibilityLevel if TYPE_CHECKING: + from t4_devkit.dataclass import Box3D from t4_devkit.typing import NDArrayF64 - from .box import Box3D - __all__ = ("view_points", "is_box_in_image") @@ -56,9 +56,7 @@ def view_points( x_ = points[0] y_ = points[1] r2 = x_**2 + y_**2 - f1 = (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) / ( - 1 + k4 * r2 + k5 * r2**2 + k6 * r2**3 - ) + f1 = (1 + k1 * r2 + k2 * r2**2 + k3 * r2**3) / (1 + k4 * r2 + k5 * r2**2 + k6 * r2**3) f2 = x_ * y_ x__ = x_ * f1 + 2 * p1 * f2 + p2 * (r2 + 2 * x_**2) + s1 * r2 + s2 * r2**2 y__ = y_ * f1 + p1 * (r2 + 2 * y_**2) + 2 * p2 * f2 + s3 * r2 + s4 * r2**2 @@ -101,9 +99,7 @@ def is_box_in_image( is_visible = np.logical_and(is_visible, corners_on_img[1, :] > 0) is_visible = np.logical_and(is_visible, 
class HomogeneousMatrix:
    """A rigid transform from a source frame to a destination frame.

    The transform is kept both as a (position, quaternion) pair and as the
    equivalent 4x4 homogeneous matrix.

    Attributes:
        position (NDArray): 3D translation.
        rotation (Quaternion): Rotation as a quaternion.
        src (str): Source frame ID.
        dst (str): Destination frame ID.
        matrix (NDArray): 4x4 homogeneous matrix built from `position` and `rotation`.
    """

    def __init__(
        self,
        position: ArrayLike,
        rotation: ArrayLike | RotationType,
        src: str,
        dst: str,
    ) -> None:
        """Construct a new object.

        Args:
            position (ArrayLike): 3D position.
            rotation (ArrayLike | RotationType): 3x3 rotation matrix or quaternion.
            src (str): Source frame ID.
            dst (str): Destination frame ID.
        """
        self.position = position if isinstance(position, np.ndarray) else np.array(position)

        # A 2-D ndarray is interpreted as a rotation matrix, anything else as quaternion elements.
        if isinstance(rotation, np.ndarray) and rotation.ndim == 2:
            rotation = Quaternion(matrix=rotation)
        elif not isinstance(rotation, Quaternion):
            rotation = Quaternion(rotation)
        self.rotation = rotation

        self.src = src
        self.dst = dst

        self.matrix = _generate_homogeneous_matrix(position, rotation)

    @classmethod
    def from_matrix(
        cls,
        matrix: NDArray | HomogeneousMatrix,
        src: str,
        dst: str,
    ) -> Self:
        """Construct a new object from a homogeneous matrix.

        Args:
            matrix (NDArray | HomogeneousMatrix): 4x4 homogeneous matrix.
            src (str): Source frame ID.
            dst (str): Destination frame ID.

        Returns:
            Self: Constructed instance.
        """
        position, rotation = _extract_position_and_rotation_from_matrix(matrix)
        return cls(position, rotation, src, dst)

    @property
    def shape(self) -> tuple[int, ...]:
        """Return a shape of the homogeneous matrix.

        Returns:
            Return the shape of (4, 4).
        """
        return self.matrix.shape

    @property
    def yaw_pitch_roll(self) -> tuple[float, float, float]:
        """Return yaw, pitch and roll.

        NOTE:
            yaw: Rotation angle around the z-axis in [rad], in the range `[-pi, pi]`.
            pitch: Rotation angle around the y'-axis in [rad], in the range `[-pi/2, pi/2]`.
            roll: Rotation angle around the x"-axis in [rad], in the range `[-pi, pi]`.

        Returns:
            Yaw, pitch and roll in [rad].
        """
        return self.rotation.yaw_pitch_roll

    @property
    def rotation_matrix(self) -> NDArray:
        """Return a 3x3 rotation matrix.

        Returns:
            3x3 rotation matrix.
        """
        return self.rotation.rotation_matrix

    def dot(self, other: HomogeneousMatrix) -> HomogeneousMatrix:
        """Return a dot product of myself and another.

        The result maps `other.src` to `self.dst`.

        Args:
            other (HomogeneousMatrix): `HomogeneousMatrix` object.

        Raises:
            ValueError: `self.src` and `other.dst` must be the same frame ID.

        Returns:
            Result of a dot product.
        """
        if self.src != other.dst:
            raise ValueError(f"self.src != other.dst: self.src={self.src}, other.dst={other.dst}")

        ret_mat = self.matrix.dot(other.matrix)
        position, rotation = _extract_position_and_rotation_from_matrix(ret_mat)
        return HomogeneousMatrix(position, rotation, src=other.src, dst=self.dst)

    def inv(self) -> HomogeneousMatrix:
        """Return a inverse matrix of myself.

        The inverse maps the destination frame back to the source frame, so the
        frame IDs of the result are swapped with respect to `self`.

        Returns:
            Inverse matrix, with `src=self.dst` and `dst=self.src`.
        """
        ret_mat = np.linalg.inv(self.matrix)
        position, rotation = _extract_position_and_rotation_from_matrix(ret_mat)
        # Swap the frames: keeping them unswapped makes `m.inv().dot(m)` fail the frame check.
        return HomogeneousMatrix(position, rotation, src=self.dst, dst=self.src)

    @overload
    def transform(self, position: ArrayLike) -> NDArray:
        """Transform a position by myself.

        Args:
            position (ArrayLike): 3D position.

        Returns:
            Transformed position.
        """
        pass

    @overload
    def transform(
        self,
        position: ArrayLike,
        rotation: RotationType,
    ) -> tuple[NDArray, Quaternion]:
        """Transform position and rotation by myself.

        Args:
            position (ArrayLike): 3D position.
            rotation (RotationType): 3x3 rotation matrix or quaternion.

        Returns:
            Transformed position and quaternion.
        """
        pass

    @overload
    def transform(self, matrix: HomogeneousMatrix) -> HomogeneousMatrix:
        """Transform a homogeneous matrix by myself.

        Args:
            matrix (HomogeneousMatrix): `HomogeneousMatrix` object.

        Returns:
            Transformed `HomogeneousMatrix` object.
        """
        pass

    def transform(self, *args, **kwargs):
        """Dispatch to the position, position-and-rotation or matrix transform.

        Accepted call shapes are `transform(position)`, `transform(position, rotation)`,
        `transform(matrix)` and their keyword equivalents, including the mixed form
        `transform(position, rotation=...)`.

        Raises:
            ValueError: No argument is given, more than 2 positional arguments are given,
                or both `position` and `matrix` keywords are given.
            KeyError: Only unexpected keywords are given.

        Returns:
            Result of the selected transform.
        """
        # TODO(ktro2828): Refactoring this operations.
        s = len(args)
        if s == 0:
            if not kwargs:
                raise ValueError("At least 1 argument specified")

            if "position" in kwargs:
                position = kwargs["position"]
                if "matrix" in kwargs:
                    raise ValueError("Cannot specify `position` and `matrix` at the same time.")
                elif "rotation" in kwargs:
                    rotation = kwargs["rotation"]
                    return self.__transform_position_and_rotation(position, rotation)
                else:
                    return self.__transform_position(position)
            elif "matrix" in kwargs:
                matrix = kwargs["matrix"]
                return self.__transform_matrix(matrix)
            else:
                raise KeyError(f"Unexpected keys are detected: {list(kwargs.keys())}")
        elif s == 1:
            arg = args[0]
            if isinstance(arg, HomogeneousMatrix):
                return self.__transform_matrix(matrix=arg)
            elif "rotation" in kwargs:
                # Mixed form `transform(position, rotation=...)`: do not drop the rotation.
                return self.__transform_position_and_rotation(arg, kwargs["rotation"])
            else:
                return self.__transform_position(position=arg)
        elif s == 2:
            position, rotation = args
            return self.__transform_position_and_rotation(position, rotation)
        else:
            raise ValueError(f"Unexpected number of arguments {s}")

    def __transform_position(self, position: ArrayLike) -> NDArray:
        # Treat the position as a pose with identity rotation and keep only the translation.
        rotation = Quaternion()
        matrix = _generate_homogeneous_matrix(position, rotation)
        ret_mat = self.matrix.dot(matrix)
        ret_pos, _ = _extract_position_and_rotation_from_matrix(ret_mat)
        return ret_pos

    def __transform_position_and_rotation(
        self,
        position: ArrayLike,
        rotation: RotationType,
    ) -> tuple[NDArray, Quaternion]:
        matrix = _generate_homogeneous_matrix(position, rotation)
        ret_mat = self.matrix.dot(matrix)
        return _extract_position_and_rotation_from_matrix(ret_mat)

    def __transform_matrix(self, matrix: HomogeneousMatrix) -> HomogeneousMatrix:
        # Delegates to `dot`, so the frame check `matrix.src == self.dst` applies.
        return matrix.dot(self)
@dataclass(eq=False)
class Box3D(BaseBox):
    """A class to represent 3D box.

    Two boxes compare equal when their `unix_time`, `semantic_label`,
    `position` and `rotation` are equal; other fields are not compared.

    Attributes:
        unix_time (int): Unix timestamp.
        frame_id (str): Coordinates frame ID where the box is with respect to.
        semantic_label (SemanticLabel): `SemanticLabel` object.
        confidence (float, optional): Confidence score of the box.
        uuid (str | None, optional): Unique box identifier.
        position (TranslationType): Box center position (x, y, z).
        rotation (RotationType): Box rotation quaternion.
        shape (Shape): `Shape` object.
        velocity (VelocityType | None, optional): Box velocity (vx, vy, vz).
        num_points (int | None, optional): The number of points inside the box.
        future (list[Trajectory] | None, optional): Box trajectory in the future of each mode.

    Examples:
        >>> # without future
        >>> box3d = Box3D(
        ...     unix_time=100,
        ...     frame_id="base_link",
        ...     semantic_label=SemanticLabel(label=LabelID.CAR, original="car"),
        ...     position=(1.0, 1.0, 1.0),
        ...     rotation=Quaternion([0.0, 0.0, 0.0, 1.0]),
        ...     shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)),
        ...     velocity=(1.0, 1.0, 1.0),
        ...     confidence=1.0,
        ...     uuid="car3d_0",
        ... )
        >>> # with future
        >>> box3d = box3d.with_future(
        ...     waypoints=[[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]]],
        ...     confidences=[1.0],
        ... )
    """

    position: TranslationType
    rotation: RotationType
    shape: Shape
    velocity: VelocityType | None = field(default=None)
    num_points: int | None = field(default=None)

    # additional attributes: set by `with_**`
    future: list[Trajectory] | None = field(default=None, init=False)

    def __post_init__(self) -> None:
        """Normalize `position`/`velocity` to ndarray and `rotation` to `Quaternion`."""
        if not isinstance(self.position, np.ndarray):
            self.position = np.array(self.position)

        if not isinstance(self.rotation, Quaternion):
            self.rotation = Quaternion(self.rotation)

        if self.velocity is not None and not isinstance(self.velocity, np.ndarray):
            self.velocity = np.array(self.velocity)

    def with_future(
        self,
        waypoints: list[TrajectoryType],
        confidences: list[float],
    ) -> Self:
        """Return a self instance setting `future` attribute.

        Note that this mutates the instance in place and returns the same object.

        Args:
            waypoints (list[TrajectoryType]): List of waypoints for each mode.
            confidences (list[float]): List of confidences for each mode.

        Returns:
            Self instance after setting `future`.
        """
        self.future = to_trajectories(waypoints, confidences)
        return self

    def __eq__(self, other: Box3D | None) -> bool:
        """Return whether `other` is a `Box3D` describing the same box.

        Args:
            other (Box3D | None): Object to compare with.

        Returns:
            `True` if time, label, position and rotation all match, otherwise `False`.
        """
        if not isinstance(other, Box3D):
            return False

        # NOTE: This comparison might be not enough
        # `position` is an ndarray, so `==` would be element-wise and the result could
        # not be used as a bool; compare the whole array instead.
        return bool(
            self.unix_time == other.unix_time
            and self.semantic_label == other.semantic_label
            and np.array_equal(self.position, other.position)
            and self.rotation == other.rotation
        )

    @property
    def size(self) -> SizeType:
        """Box size, taken from `shape.size`."""
        return self.shape.size

    @property
    def footprint(self) -> Polygon:
        """Box footprint polygon, taken from `shape.footprint`."""
        return self.shape.footprint

    @property
    def area(self) -> float:
        """Area of the footprint polygon."""
        return self.shape.footprint.area

    @property
    def volume(self) -> float:
        """Footprint area multiplied by the third element of `size`."""
        return self.area * self.size[2]

    def corners(self, box_scale: float = 1.0) -> NDArrayF64:
        """Return the bounding box corners.

        Args:
            box_scale (float, optional): Multiply size by this factor to scale the box.

        Returns:
            First four corners are the ones facing forward. The last four are the ones facing backwards,
                in the shape of (8, 3).
        """
        length, width, height = self.size * box_scale

        # 3D box corners (Convention: x points forward, y to the left, z up.)
        x_corners = 0.5 * length * np.array([1, 1, 1, 1, -1, -1, -1, -1])
        y_corners = 0.5 * width * np.array([1, -1, -1, 1, 1, -1, -1, 1])
        z_corners = 0.5 * height * np.array([1, 1, -1, -1, 1, 1, -1, -1])
        corners = np.vstack((x_corners, y_corners, z_corners))  # (3, 8)

        # Rotate and translate
        return np.dot(self.rotation.rotation_matrix, corners).T + self.position
def __post_init__(self) -> None:
    """Wrap a raw `roi` value into a `Roi` object.

    Nothing is done when `roi` is `None` or already a `Roi` instance.
    """
    current = self.roi
    if current is None or isinstance(current, Roi):
        return
    self.roi = Roi(current)
@unique
class LabelID(Enum):
    """Enum of label elements.

    Members compare equal to other `LabelID` members with the same name and to
    strings that match the member name case-insensitively.
    """

    # catch all labels
    UNKNOWN = 0

    # object labels
    CAR = auto()
    TRUCK = auto()
    BUS = auto()
    BICYCLE = auto()
    MOTORBIKE = auto()
    PEDESTRIAN = auto()
    ANIMAL = auto()

    # traffic-light labels
    TRAFFIC_LIGHT = auto()
    GREEN = auto()
    GREEN_STRAIGHT = auto()
    GREEN_LEFT = auto()
    GREEN_RIGHT = auto()
    YELLOW = auto()
    YELLOW_STRAIGHT = auto()
    YELLOW_LEFT = auto()
    YELLOW_RIGHT = auto()
    YELLOW_STRAIGHT_LEFT = auto()
    YELLOW_STRAIGHT_RIGHT = auto()
    YELLOW_STRAIGHT_LEFT_RIGHT = auto()
    RED = auto()
    RED_STRAIGHT = auto()
    RED_LEFT = auto()
    RED_RIGHT = auto()
    RED_STRAIGHT_LEFT = auto()
    RED_STRAIGHT_RIGHT = auto()
    RED_STRAIGHT_LEFT_RIGHT = auto()
    RED_LEFT_DIAGONAL = auto()
    RED_RIGHT_DIAGONAL = auto()

    @classmethod
    def from_name(cls, name: str) -> Self:
        """Return the member whose name matches `name` case-insensitively.

        Args:
            name (str): Label name, in any letter case.

        Raises:
            AssertionError: `name` is not a member name.

        Returns:
            Matching member.
        """
        name = name.upper()
        assert name in cls.__members__, f"Unexpected label name: {name}"
        return cls.__members__[name]

    def __eq__(self, other: object) -> bool:
        """Compare with another member or with a label name string.

        Args:
            other (object): `LabelID`, `str` or any other object.

        Returns:
            `True` for a member with the same name or a string equal to the name
            ignoring case. Any other type compares unequal instead of raising.
        """
        if isinstance(other, str):
            return self.name == other.upper()
        if isinstance(other, LabelID):
            return self.name == other.name
        return NotImplemented

    def __hash__(self) -> int:
        # Overriding `__eq__` alone sets `__hash__` to None and makes members unusable
        # as dict keys or set elements. Hash by name, which is what `__eq__` compares.
        return hash(self.name)
def convert_label(
    original: str,
    attributes: list[str] | None = None,
    *,
    name_mapping: dict[str, str] | None = None,
    update_default_mapping: bool = False,
) -> SemanticLabel:
    """Convert string original label name to `SemanticLabel` object.

    Args:
        original (str): Original label name. For example, `vehicle.car`.
        attributes (list[str] | None, optional): List of label attributes.
        name_mapping (dict[str, str] | None, optional): Name mapping for original and label.
            If `None`, `DEFAULT_NAME_MAPPING` will be used.
        update_default_mapping (bool, optional): Whether to update `DEFAULT_NAME_MAPPING` by
            `name_mapping`. If `True`, `DEFAULT_NAME_MAPPING` is updated in place and the merged
            mapping is used for the lookup. If `False` and `name_mapping` is specified,
            the specified `name_mapping` is used instead of `DEFAULT_NAME_MAPPING` completely.
            Note that, this parameter works only if `name_mapping` is specified.

    Returns:
        Converted `SemanticLabel` object. Names missing from the mapping are converted to
        `UNKNOWN` with a `UserWarning`.
    """
    # set name mapping
    if name_mapping is None:
        name_mapping = DEFAULT_NAME_MAPPING
    elif update_default_mapping:
        # The module-level default is mutated on purpose (it persists across calls);
        # look up in the merged table so that the default entries stay reachable.
        DEFAULT_NAME_MAPPING.update(name_mapping)
        name_mapping = DEFAULT_NAME_MAPPING

    # convert original to name for Label object
    if original in name_mapping:
        name = name_mapping[original]
    else:
        warnings.warn(
            f"{original} is not included in mapping, use UNKNOWN.",
            UserWarning,
        )
        name = "UNKNOWN"

    label = LabelID.from_name(name)

    return (
        SemanticLabel(label, original)
        if attributes is None
        else SemanticLabel(label, original, attributes)
    )
def num_points(self) -> int:
    """Return the number of points.

    Returns:
        int: Number of points, read from the second axis of `self.points`.
    """
    return self.points.shape[1]
and check if they appear as expected. + assert metadata[0].startswith("#"), "First line must be comment" + assert metadata[1].startswith("VERSION"), "Second line must be VERSION" + sizes = metadata[3].split(" ")[1:] + types = metadata[4].split(" ")[1:] + counts = metadata[5].split(" ")[1:] + width = int(metadata[6].split(" ")[1]) + height = int(metadata[7].split(" ")[1]) + data = metadata[10].split(" ")[1] + feature_count = len(types) + assert width > 0 + assert len([c for c in counts if c != c]) == 0, "Error: COUNT not supported!" + assert height == 1, "Error: height != 0 not supported!" + assert data == "binary" + + # Lookup table for how to decode the binaries. + unpacking_lut = { + "F": {2: "e", 4: "f", 8: "d"}, + "I": {1: "b", 2: "h", 4: "i", 8: "q"}, + "U": {1: "B", 2: "H", 4: "I", 8: "Q"}, + } + types_str = "".join([unpacking_lut[t][int(s)] for t, s in zip(types, sizes)]) + + # Decode each point. + offset = 0 + point_count = width + points = [] + for i in range(point_count): + point = [] + for p in range(feature_count): + start_p = offset + end_p = start_p + int(sizes[p]) + assert end_p < len(data_binary) + point_p = struct.unpack(types_str[p], data_binary[start_p:end_p])[0] + point.append(point_p) + offset = end_p + points.append(point) + + # A NaN in the first point indicates an empty pointcloud. + point = np.array(points[0]) + if np.any(np.isnan(point)): + return cls(np.zeros((feature_count, 0))) + + # Convert to numpy matrix. + points = np.array(points).transpose() + + # If no parameters are provided, use default settings. + invalid_states = cls.invalid_states if invalid_states is None else invalid_states + dynprop_states = cls.dynprop_states if dynprop_states is None else dynprop_states + ambig_states = cls.ambig_states if ambig_states is None else ambig_states + + # Filter points with an invalid state. + valid = [p in invalid_states for p in points[-4, :]] + points = points[:, valid] + + # Filter by dynProp. 
+ valid = [p in dynprop_states for p in points[3, :]] + points = points[:, valid] + + # Filter by ambig_state. + valid = [p in ambig_states for p in points[11, :]] + points = points[:, valid] + + return cls(points) + + +@dataclass +class SegmentationPointCloud(PointCloud): + labels: NDArrayU8 + + @staticmethod + def num_dims() -> int: + return 4 + + @classmethod + def from_file(cls, point_filepath: str, label_filepath: str) -> Self: + scan = np.fromfile(point_filepath, dtype=np.float32) + points = scan.reshape((-1, 5))[:, : cls.num_dims()] + labels = np.fromfile(label_filepath, dtype=np.uint8) + return cls(points.T, labels) diff --git a/t4_devkit/dataclass/roi.py b/t4_devkit/dataclass/roi.py new file mode 100644 index 0000000..f9e4a5f --- /dev/null +++ b/t4_devkit/dataclass/roi.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from t4_devkit.typing import RoiType + +__all__ = ["Roi"] + + +@dataclass +class Roi: + roi: RoiType + + def __post_init__(self) -> None: + assert len(self.roi) == 4, ( + "Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}." 
@dataclass
class Roi:
    """A dataclass to represent a 2D region of interest.

    Attributes:
        roi (RoiType): Region in the order of (x, y, width, height). Any other sequence of
            length 4 is converted to a tuple.
    """

    roi: RoiType

    def __post_init__(self) -> None:
        assert len(self.roi) == 4, (
            "Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
        )

        # Store a tuple so that the slices returned by the properties are tuples as well.
        if not isinstance(self.roi, tuple):
            self.roi = tuple(self.roi)

    @property
    def offset(self) -> tuple[int, int]:
        """Return the first two elements of `roi`, (x, y)."""
        return self.roi[:2]

    @property
    def size(self) -> tuple[int, int]:
        """Return the last two elements of `roi`, (width, height)."""
        return self.roi[2:]

    @property
    def width(self) -> int:
        """Return the width of the region."""
        return self.size[0]

    @property
    def height(self) -> int:
        """Return the height of the region."""
        return self.size[1]

    @property
    def center(self) -> tuple[int, int]:
        """Return the center of the region, computed with floor division of the size."""
        ox, oy = self.offset
        w, h = self.size
        return ox + w // 2, oy + h // 2

    @property
    def area(self) -> int:
        """Return the area of the region, width * height."""
        w, h = self.size
        return w * h


@unique
class ShapeType(Enum):
    """An enum to represent the type of shape."""

    BOUNDING_BOX = 0
    POLYGON = auto()

    @classmethod
    def from_name(cls, name: str) -> Self:
        """Return an enum object from the name of the member.

        Args:
            name (str): Name of enum member. The lookup is case-insensitive.

        Returns:
            Enum object.
        """
        name = name.upper()
        assert name in cls.__members__, f"Unexpected shape type: {name}."
        return cls.__members__[name]


@dataclass
class Shape:
    """A dataclass to represent the 3D box shape.

    Attributes:
        shape_type (ShapeType): Type of the shape.
        size (SizeType): Size of the box. It is converted to `np.ndarray` if needed.
        footprint (Polygon | None, optional): Footprint of the shape. It must be specified for
            `ShapeType.POLYGON`; otherwise, if `None`, it is calculated from `size`.

    Examples:
        >>> shape = Shape(
        ...     shape_type=ShapeType.BOUNDING_BOX,
        ...     size=[1.0, 1.0, 1.0]
        ... )
    """

    shape_type: ShapeType
    size: SizeType
    footprint: Polygon | None = field(default=None)

    def __post_init__(self) -> None:
        if not isinstance(self.size, np.ndarray):
            self.size = np.array(self.size)

        # A polygon footprint cannot be derived from the box size.
        if self.shape_type == ShapeType.POLYGON and self.footprint is None:
            raise ValueError("`footprint` must be specified for `POLYGON`.")

        if self.footprint is None:
            self.footprint = _calculate_footprint(self.size)


def _calculate_footprint(size: SizeType) -> Polygon:
    """Return a footprint of box as `Polygon` object.

    Args:
        size (SizeType): Size of box. The x-extent of the footprint is taken from `size[1]` and
            the y-extent from `size[0]`; `size[2]` is not used.

    Returns:
        Closed ring centered at the origin on the z=0 plane, visiting the corners in the order of
        (+x, +y), (-x, +y), (-x, -y), (+x, -y).
    """

    corners: list[NDArrayF64] = [
        np.array([size[1], size[0], 0.0]) / 2.0,
        np.array([-size[1], size[0], 0.0]) / 2.0,
        np.array([-size[1], -size[0], 0.0]) / 2.0,
        np.array([size[1], -size[0], 0.0]) / 2.0,
    ]

    # Repeat the first corner to close the ring explicitly.
    return Polygon([*corners, corners[0]])
+ >>> trajectory.shape + (2, 3) + # Access each point as subscriptable. + >>> trajectory[0] + array([1., 1., 1.]) + # Access each point as iterable. + >>> for point in trajectory: + ... print(point) + ... + [1. 1. 1.] + [2. 2. 2.] + """ + + waypoints: TrajectoryType + confidence: float = field(default=1.0) + + def __post_init__(self) -> None: + if not isinstance(self.waypoints, np.ndarray): + self.waypoints = np.array(self.waypoints) + + assert self.waypoints.shape[1] == 3 + + def __len__(self) -> int: + return len(self.waypoints) + + def __getitem__(self, index: int) -> TranslationType: + return self.waypoints[index] + + def __iter__(self) -> Generator[TrajectoryType]: + yield from self.waypoints + + @property + def shape(self) -> tuple[int, ...]: + """Return the shape of the waypoints matrix. + + Returns: + Shape of the matrix (N, 3). + """ + return self.waypoints.shape + + +def to_trajectories( + waypoints: list[TrajectoryType], + confidences: list[float], +) -> list[Trajectory]: + """Convert a list of waypoints and confidences to a list of `Trajectory`s for each mode. + + Args: + waypoints (list[TrajectoryType]): List of waypoints for each mode. + confidences (list[float]): List of confidences for each mode. + + Returns: + List of `Trajectory`s for each mode. 
+ """ + return [ + Trajectory(points, confidence) + for points, confidence in zip(waypoints, confidences, strict=True) + ] diff --git a/t4_devkit/schema/tables/ego_pose.py b/t4_devkit/schema/tables/ego_pose.py index 9624e09..6925378 100644 --- a/t4_devkit/schema/tables/ego_pose.py +++ b/t4_devkit/schema/tables/ego_pose.py @@ -41,6 +41,4 @@ def from_dict(cls, data: dict[str, Any]) -> Self: rotation = Quaternion(data["rotation"]) timestamp: int = data["timestamp"] - return cls( - token=token, translation=translation, rotation=rotation, timestamp=timestamp - ) + return cls(token=token, translation=translation, rotation=rotation, timestamp=timestamp) diff --git a/t4_devkit/schema/tables/registry.py b/t4_devkit/schema/tables/registry.py index 30cf6b8..16af39d 100644 --- a/t4_devkit/schema/tables/registry.py +++ b/t4_devkit/schema/tables/registry.py @@ -51,9 +51,7 @@ def _register_decorator(obj: object) -> object: return _register_decorator - def _add_module( - self, module: object, name: SchemaName, *, force: bool = False - ) -> None: + def _add_module(self, module: object, name: SchemaName, *, force: bool = False) -> None: if not inspect.isclass(module): raise TypeError(f"module must be a class, but got {type(module)}.") @@ -62,9 +60,7 @@ def _add_module( self.__schemas[name.value] = module - def build_from_json( - self, key: str | SchemaName, filepath: str - ) -> list[SchemaTable]: + def build_from_json(self, key: str | SchemaName, filepath: str) -> list[SchemaTable]: """Build schema dataclass from json. 
Args: diff --git a/t4_devkit/schema/tables/sample_annotation.py b/t4_devkit/schema/tables/sample_annotation.py index ea4be4b..623bb6a 100644 --- a/t4_devkit/schema/tables/sample_annotation.py +++ b/t4_devkit/schema/tables/sample_annotation.py @@ -84,9 +84,7 @@ def from_dict(cls, data: dict[str, Any]) -> Self: visibility_token: str = data["visibility_token"] translation = np.array(data["translation"]) velocity = np.array(data["velocity"]) if data.get("velocity") else None - acceleration = ( - np.array(data["acceleration"]) if data.get("acceleration") else None - ) + acceleration = np.array(data["acceleration"]) if data.get("acceleration") else None size = np.array(data["size"]) rotation = Quaternion(data["rotation"]) num_lidar_pts: int = data["num_lidar_pts"] diff --git a/t4_devkit/tier4.py b/t4_devkit/tier4.py index a131c31..8898b23 100644 --- a/t4_devkit/tier4.py +++ b/t4_devkit/tier4.py @@ -1,26 +1,35 @@ from __future__ import annotations -from dataclasses import dataclass, field import os import os.path as osp import time +from dataclasses import dataclass, field from typing import TYPE_CHECKING -from PIL import Image import numpy as np -from nuscenes.nuscenes import LidarPointCloud, RadarPointCloud -from pyquaternion import Quaternion import rerun as rr import rerun.blueprint as rrb -from t4_devkit.common.box import Box2D, Box3D +from PIL import Image +from pyquaternion import Quaternion + from t4_devkit.common.color import distance_color from t4_devkit.common.geometry import is_box_in_image, view_points from t4_devkit.common.timestamp import sec2us, us2sec +from t4_devkit.dataclass import ( + Box2D, + Box3D, + LidarPointCloud, + RadarPointCloud, + Shape, + ShapeType, + convert_label, +) from t4_devkit.schema import SchemaName, SensorModality, VisibilityLevel, build_schema if TYPE_CHECKING: from rerun.blueprint.api import BlueprintLike, Container, SpaceView from rerun.recording_stream import RecordingStream + from t4_devkit.typing import ( CamIntrinsicType, 
NDArrayF64, @@ -31,6 +40,7 @@ VelocityType, ) + from .dataclass import BoxType, SemanticLabel from .schema import ( Attribute, CalibratedSensor, @@ -98,9 +108,7 @@ def __init__(self, version: str, data_root: str, verbose: bool = True) -> None: self.verbose = verbose if not osp.exists(self.data_root): - raise FileNotFoundError( - f"Database directory is not found: {self.data_root}" - ) + raise FileNotFoundError(f"Database directory is not found: {self.data_root}") start_time = time.time() if verbose: @@ -218,9 +226,7 @@ def __make_reverse_index__(self, verbose: bool) -> None: sample_record.ann_3ds.append(ann_record.token) for ann_record in self.object_ann: - sd_record: SampleData = self.get( - "sample_data", ann_record.sample_data_token - ) + sd_record: SampleData = self.get("sample_data", ann_record.sample_data_token) sample_record: Sample = self.get("sample", sd_record.sample_token) sample_record.ann_2ds.append(ann_record.token) @@ -297,7 +303,7 @@ def get_sample_data( *, as_3d: bool = True, visibility: VisibilityLevel = VisibilityLevel.NONE, - ) -> tuple[str, list[Box3D | Box2D], CamIntrinsicType | None]: + ) -> tuple[str, list[BoxType], CamIntrinsicType | None]: """Return the data path as well as all annotations related to that `sample_data`. Args: @@ -329,7 +335,7 @@ def get_sample_data( img_size = None # Retrieve all sample annotations and map to sensor coordinate system. 
- boxes: list[Box3D | Box2D] + boxes: list[BoxType] if selected_ann_tokens is not None: boxes = ( list(map(self.get_box3d, selected_ann_tokens)) @@ -338,9 +344,7 @@ def get_sample_data( ) else: boxes = ( - self.get_box3ds(sample_data_token) - if as_3d - else self.get_box2ds(sample_data_token) + self.get_box3ds(sample_data_token) if as_3d else self.get_box2ds(sample_data_token) ) if not as_3d: @@ -368,6 +372,39 @@ def get_sample_data( return data_path, box_list, cam_intrinsic + def get_semantic_label( + self, + category_token: str, + attribute_tokens: list[str] | None = None, + name_mapping: dict[str, str] | None = None, + *, + update_default_mapping: bool = False, + ) -> SemanticLabel: + """Return a SemanticLabel instance from specified `category_token` and `attribute_tokens`. + + Args: + category_token (str): Token of `Category` table. + attribute_tokens (list[str] | None, optional): List of attribute tokens. + name_mapping (dict[str, str] | None, optional): Category name mapping. + update_default_mapping (bool, optional): Whether to update default category name mapping. + + Returns: + Instantiated SemanticLabel. + """ + category: Category = self.get("category", category_token) + attributes: list[str] = ( + [self.get("attribute", token).name for token in attribute_tokens] + if attribute_tokens is not None + else [] + ) + + return convert_label( + original=category.name, + attributes=attributes, + name_mapping=name_mapping, + update_default_mapping=update_default_mapping, + ) + def get_box3d(self, sample_annotation_token: str) -> Box3D: """Return a Box3D class from a `sample_annotation` record. @@ -377,15 +414,31 @@ def get_box3d(self, sample_annotation_token: str) -> Box3D: Returns: Instantiated Box3D. 
""" - record: SampleAnnotation = self.get( - "sample_annotation", sample_annotation_token + ann: SampleAnnotation = self.get("sample_annotation", sample_annotation_token) + instance: Instance = self.get("instance", ann.instance_token) + sample: Sample = self.get("sample", ann.sample_token) + + # semantic label + semantic_label = self.get_semantic_label( + category_token=instance.category_token, + attribute_tokens=ann.attribute_tokens, ) + + shape = Shape(shape_type=ShapeType.BOUNDING_BOX, size=ann.size) + + # velocity + velocity = self.box_velocity(sample_annotation_token=sample_annotation_token) + return Box3D( - record.translation, - record.size, - record.rotation, - name=record.category_name, - token=record.token, + unix_time=sample.timestamp, + frame_id="map", + semantic_label=semantic_label, + position=ann.translation, + rotation=ann.rotation, + shape=shape, + velocity=velocity, + confidence=1.0, + uuid=instance.token, # TODO(ktro2828): extract uuid from `instance_name`. ) def get_box2d(self, object_ann_token: str) -> Box2D: @@ -397,8 +450,23 @@ def get_box2d(self, object_ann_token: str) -> Box2D: Returns: Instantiated Box2D. """ - record: ObjectAnn = self.get("object_ann", object_ann_token) - return Box2D(record.bbox, name=record.category_name, token=record.token) + ann: ObjectAnn = self.get("object_ann", object_ann_token) + instance: Instance = self.get("instance", ann.instance_token) + sample_data: SampleData = self.get("sample_data", ann.sample_data_token) + + semantic_label = self.get_semantic_label( + category_token=ann.category_token, + attribute_tokens=ann.attribute_tokens, + ) + + return Box2D( + unix_time=sample_data.timestamp, + frame_id=sample_data.channel, + semantic_label=semantic_label, + roi=ann.bbox, + confidence=1.0, + uuid=instance.token, # TODO(ktro2828): extract uuid from `instance_name`. 
+ ) def get_box3ds(self, sample_data_token: str) -> list[Box3D]: """Rerun a list of Box3D classes for all annotations of a particular `sample_data` record. @@ -422,12 +490,10 @@ def get_box3ds(self, sample_data_token: str) -> list[Box3D]: prev_sample_record: Sample = self.get("sample", curr_sample_record.prev) curr_ann_recs: list[SampleAnnotation] = [ - self.get("sample_annotation", token) - for token in curr_sample_record.ann_3ds + self.get("sample_annotation", token) for token in curr_sample_record.ann_3ds ] prev_ann_recs: list[SampleAnnotation] = [ - self.get("sample_annotation", token) - for token in prev_sample_record.ann_3ds + self.get("sample_annotation", token) for token in prev_sample_record.ann_3ds ] # Maps instance tokens to prev_ann records @@ -508,9 +574,7 @@ def box_velocity( Returns: VelocityType: Velocity in the order of (vx, vy, vz) in m/s. """ - current: SampleAnnotation = self.get( - "sample_annotation", sample_annotation_token - ) + current: SampleAnnotation = self.get("sample_annotation", sample_annotation_token) # If the real velocity is annotated, returns it if current.velocity is not None: @@ -581,13 +645,9 @@ def project_pointcloud( elif point_sample_data.modality == SensorModality.RADAR: pointcloud = RadarPointCloud.from_file(pc_filepath) else: - raise ValueError( - f"Expected sensor lidar/radar, but got {point_sample_data.modality}" - ) + raise ValueError(f"Expected sensor lidar/radar, but got {point_sample_data.modality}") - camera_sample_data: SampleData = self.get( - "sample_data", camera_sample_data_token - ) + camera_sample_data: SampleData = self.get("sample_data", camera_sample_data_token) if camera_sample_data.modality != SensorModality.CAMERA: f"Expected camera, but got {camera_sample_data.modality}" @@ -606,9 +666,7 @@ def project_pointcloud( pointcloud.translate(point_ego_pose.translation) # 3. 
transform from global into the ego vehicle frame for the timestamp of the image - camera_ego_pose: EgoPose = self.get( - "ego_pose", camera_sample_data.ego_pose_token - ) + camera_ego_pose: EgoPose = self.get("ego_pose", camera_sample_data.ego_pose_token) pointcloud.translate(-camera_ego_pose.translation) pointcloud.rotate(camera_ego_pose.rotation.rotation_matrix.T) @@ -698,9 +756,7 @@ def render_scene( self._render_annotation_2ds(scene.first_sample_token, max_timestamp_us) if save_dir is not None: - self._save_viewer( - save_dir, application_id + ".rrd", default_blueprint=blueprint - ) + self._save_viewer(save_dir, application_id + ".rrd", default_blueprint=blueprint) def render_instance( self, @@ -718,9 +774,7 @@ def render_instance( """ # search first sample associated with the instance instance: Instance = self.get("instance", instance_token) - first_ann: SampleAnnotation = self.get( - "sample_annotation", instance.first_annotation_token - ) + first_ann: SampleAnnotation = self.get("sample_annotation", instance.first_annotation_token) first_sample: Sample = self.get("sample", first_ann.sample_token) # search first sample data tokens @@ -749,9 +803,7 @@ def render_instance( spawn=show, ) - last_ann: SampleAnnotation = self.get( - "sample_annotation", instance.last_annotation_token - ) + last_ann: SampleAnnotation = self.get("sample_annotation", instance.last_annotation_token) last_sample: Sample = self.get("sample", last_ann.sample_token) max_timestamp_us = last_sample.timestamp @@ -774,9 +826,7 @@ def render_instance( ) if save_dir is not None: - self._save_viewer( - save_dir, application_id + ".rrd", default_blueprint=blueprint - ) + self._save_viewer(save_dir, application_id + ".rrd", default_blueprint=blueprint) def render_pointcloud( self, @@ -812,9 +862,7 @@ def render_pointcloud( # initialize viewer application_id = f"t4-devkit@{scene_token}" - blueprint = self._init_viewer( - application_id, render_annotation=False, spawn=show - ) + blueprint = 
self._init_viewer(application_id, render_annotation=False, spawn=show) first_lidar_sd_record: SampleData = self.get("sample_data", first_lidar_token) max_timestamp_us = first_lidar_sd_record.timestamp + sec2us(max_time_seconds) @@ -827,9 +875,7 @@ def render_pointcloud( ) if save_dir is not None: - self._save_viewer( - save_dir, application_id + ".rrd", default_blueprint=blueprint - ) + self._save_viewer(save_dir, application_id + ".rrd", default_blueprint=blueprint) def _init_viewer( self, @@ -868,9 +914,7 @@ def _init_viewer( if render_2d: camera_names = [ - sensor.channel - for sensor in self.sensor - if sensor.modality == SensorModality.CAMERA + sensor.channel for sensor in self.sensor if sensor.modality == SensorModality.CAMERA ] camera_space_views = [ rrb.Spatial2DView(name=camera, origin=f"world/ego_vehicle/{camera}") @@ -975,9 +1019,7 @@ def _render_lidar_and_ego( ) sensor_name = sample_data.channel - pointcloud = LidarPointCloud.from_file( - osp.join(self.data_root, sample_data.filename) - ) + pointcloud = LidarPointCloud.from_file(osp.join(self.data_root, sample_data.filename)) points = pointcloud.points[:3].T # (N, 3) point_colors = distance_color(np.linalg.norm(points, axis=1)) rr.log( @@ -994,9 +1036,7 @@ def _render_lidar_and_ego( current_lidar_token = sample_data.next - def _render_radars( - self, first_radar_tokens: list[str], max_timestamp_us: float - ) -> None: + def _render_radars(self, first_radar_tokens: list[str], max_timestamp_us: float) -> None: """Render radar pointcloud. Args: @@ -1027,9 +1067,7 @@ def _render_radars( ) current_radar_token = sample_data.next - def _render_cameras( - self, first_camera_tokens: list[str], max_timestamp_us: float - ) -> None: + def _render_cameras(self, first_camera_tokens: list[str], max_timestamp_us: float) -> None: """Render camera images. 
Args: @@ -1051,9 +1089,7 @@ def _render_cameras( sensor_name = sample_data.channel rr.log( f"world/ego_vehicle/{sensor_name}", - rr.ImageEncoded( - path=osp.join(self.data_root, sample_data.filename) - ), + rr.ImageEncoded(path=osp.join(self.data_root, sample_data.filename)), ) current_camera_token = sample_data.next diff --git a/t4_devkit/typing.py b/t4_devkit/typing.py index b891438..4843339 100644 --- a/t4_devkit/typing.py +++ b/t4_devkit/typing.py @@ -14,12 +14,15 @@ "NDArrayU8", "NDArrayBool", "NDArrayStr", + "NDArrayInt", + "NDArrayFloat", "TranslationType", "VelocityType", "AccelerationType", "RotationType", "VelocityType", "SizeType", + "TrajectoryType", "CamIntrinsicType", "CamDistortionType", "RoiType", @@ -36,12 +39,16 @@ NDArrayBool = NDArray[np.bool_] NDArrayStr = NDArray[np.str_] +NDArrayInt = NDArrayI32 | NDArrayI64 +NDArrayFloat = NDArrayF32 | NDArrayF64 + # 3D TranslationType = NewType("TranslationType", NDArrayF64) VelocityType = NewType("VelocityType", NDArrayF64) AccelerationType = NewType("AccelerationType", NDArrayF64) RotationType = NewType("RotationType", Quaternion) SizeType = NewType("SizeType", NDArrayF64) +TrajectoryType = NewType("TrajectoryType", NDArrayF64) CamIntrinsicType = NewType("CamIntrinsicType", NDArrayF64) CamDistortionType = NewType("CamDistortionType", NDArrayF64) diff --git a/tests/common/test_geometry.py b/tests/common/test_geometry.py index 42a1b81..1c9040f 100644 --- a/tests/common/test_geometry.py +++ b/tests/common/test_geometry.py @@ -1,4 +1,5 @@ import numpy as np + from t4_devkit.common.geometry import view_points @@ -67,8 +68,6 @@ def test_view_points_with_distortion() -> None: project = view_points(points, intrinsic, distortion) - print(project) - expect = np.array( [ [0.5413125, -0.5113125], diff --git a/tests/common/test_transform.py b/tests/common/test_transform.py new file mode 100644 index 0000000..412733b --- /dev/null +++ b/tests/common/test_transform.py @@ -0,0 +1,89 @@ +from __future__ import annotations 
def test_homogeneous_matrix_transform():
    """Check `HomogeneousMatrix.transform()` with positions, poses and matrices."""
    ego2map = HomogeneousMatrix((1, 0, 0), (1, 0, 0, 0), src="base_link", dst="map")
    shifted = np.array((2, 0, 0))

    def _assert_identity(result) -> None:
        # The composed transform is expected to be the identity.
        assert np.allclose(result.matrix, np.eye(4))
        assert np.allclose(result.position, np.zeros(3))
        assert np.allclose(result.rotation_matrix, np.eye(3))

    # position only: positional argument, then keyword argument
    pos1 = ego2map.transform((1, 0, 0))
    assert np.allclose(pos1, shifted)

    pos2 = ego2map.transform(position=(1, 0, 0))
    assert np.allclose(pos2, shifted)

    # position and rotation: positional arguments, then keyword arguments
    pos1, rot1 = ego2map.transform((1, 0, 0), (1, 0, 0, 0))
    assert np.allclose(pos1, shifted)
    assert np.allclose(rot1.rotation_matrix, np.eye(3))

    pos2, rot2 = ego2map.transform(position=(1, 0, 0), rotation=(1, 0, 0, 0))
    assert np.allclose(pos2, shifted)
    assert np.allclose(rot2.rotation_matrix, np.eye(3))

    # matrix: positional argument, then keyword argument
    map2ego = HomogeneousMatrix((-1, 0, 0), (1, 0, 0, 0), src="map", dst="base_link")
    _assert_identity(ego2map.transform(map2ego))
    _assert_identity(ego2map.transform(matrix=map2ego))


def test_homogenous_matrix_dot():
    """Check that `dot()` chains two transforms and keeps the outer frame ids."""
    ego2map = HomogeneousMatrix((1, 1, 1), (1, 0, 0, 0), src="base_link", dst="map")
    cam2ego = HomogeneousMatrix((2, 2, 2), (1, 0, 0, 0), src="camera", dst="base_link")

    expected = np.array(
        [
            [1, 0, 0, 3],
            [0, 1, 0, 3],
            [0, 0, 1, 3],
            [0, 0, 0, 1],
        ],
    )

    cam2map = ego2map.dot(cam2ego)

    assert np.allclose(cam2map.matrix, expected)
    assert np.allclose(cam2map.position, np.array([3, 3, 3]))  # cam position in map coords
    assert np.allclose(cam2map.rotation_matrix, np.eye(3))  # cam rotation matrix in map coords
    assert cam2map.src == "camera"
    assert cam2map.dst == "map"


def test_homogenous_matrix_inv():
    """Check that `inv()` returns the inverse matrix with matching position and rotation."""
    forward = np.array(
        [
            [0.70710678, -0.70710678, 0.0, 1.0],
            [0.70710678, 0.70710678, 0.0, 2.0],
            [0.0, 0.0, 1.0, 3.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    expected_matrix = np.array(
        [
            [0.70710678, 0.70710678, 0.0, -2.12132034],
            [-0.70710678, 0.70710678, 0.0, -0.70710678],
            [0.0, 0.0, 1.0, -3.0],
            [0.0, 0.0, 0.0, 1.0],
        ]
    )
    expected_position = np.array([-2.12132034, -0.70710678, -3.0])
    expected_rotation = np.array(
        [
            [0.70710678, 0.70710678, 0.0],
            [-0.70710678, 0.70710678, 0.0],
            [0.0, 0.0, 1.0],
        ]
    )

    ego2map = HomogeneousMatrix.from_matrix(forward, src="base_link", dst="map")
    inv = ego2map.inv()

    assert np.allclose(inv.matrix, expected_matrix)
    assert np.allclose(inv.position, expected_position)
    assert np.allclose(inv.rotation_matrix, expected_rotation)
0.0, 1.0], + ] + ) + ego2map = HomogeneousMatrix.from_matrix(matrix, src="base_link", dst="map") + inv = ego2map.inv() + assert np.allclose( + inv.matrix, + np.array( + [ + [0.70710678, 0.70710678, 0.0, -2.12132034], + [-0.70710678, 0.70710678, 0.0, -0.70710678], + [0.0, 0.0, 1.0, -3.0], + [0.0, 0.0, 0.0, 1.0], + ] + ), + ) + assert np.allclose(inv.position, np.array([-2.12132034, -0.70710678, -3.0])) + assert np.allclose( + inv.rotation_matrix, + np.array( + [ + [0.70710678, 0.70710678, 0.0], + [-0.70710678, 0.70710678, 0.0], + [0.0, 0.0, 1.0], + ] + ), + ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..23c6609 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,120 @@ +import pytest +from pyquaternion import Quaternion + +from t4_devkit.dataclass import Box2D, Box3D, LabelID, SemanticLabel, Shape, ShapeType + + +@pytest.fixture(scope="module") +def dummy_box3d() -> Box3D: + """Return a dummy 3D box. + + Returns: + A 3D box. + """ + return Box3D( + unix_time=100, + frame_id="base_link", + semantic_label=SemanticLabel(label=LabelID.CAR, original="car"), + position=(1.0, 1.0, 1.0), + rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), + shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), + velocity=(1.0, 1.0, 1.0), + confidence=1.0, + uuid="car3d_0", + ) + + +@pytest.fixture(scope="module") +def dummy_box3ds() -> list[Box3D]: + """Return a list of dummy 3D boxes. + + Returns: + List of 3D boxes. 
+ """ + return [ + Box3D( + unix_time=100, + frame_id="base_link", + semantic_label=SemanticLabel(label=LabelID.CAR, original="car"), + position=(1.0, 1.0, 1.0), + rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), + shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), + velocity=(1.0, 1.0, 1.0), + confidence=1.0, + uuid="car3d_1", + ), + Box3D( + unix_time=100, + frame_id="base_link", + semantic_label=SemanticLabel(label=LabelID.BICYCLE, original="bicycle"), + position=(-1.0, -1.0, 1.0), + rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), + shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), + velocity=(1.0, 1.0, 1.0), + confidence=1.0, + uuid="bicycle3d_1", + ), + Box3D( + unix_time=100, + frame_id="base_link", + semantic_label=SemanticLabel(label=LabelID.PEDESTRIAN, original="pedestrian"), + position=(-1.0, 1.0, 1.0), + rotation=Quaternion([0.0, 0.0, 0.0, 1.0]), + shape=Shape(shape_type=ShapeType.BOUNDING_BOX, size=(1.0, 1.0, 1.0)), + velocity=(1.0, 1.0, 1.0), + confidence=1.0, + uuid="pedestrian3d_1", + ), + ] + + +@pytest.fixture(scope="module") +def dummy_box2d() -> Box2D: + """Return a dummy 2D box. + + Returns: + A 2D box. + """ + return Box2D( + unix_time=100, + frame_id="camera", + semantic_label=SemanticLabel(label=LabelID.CAR, original="car"), + roi=(100, 100, 50, 50), + confidence=1.0, + uuid="car2d_0", + ) + + +@pytest.fixture(scope="module") +def dummy_box2ds() -> list[Box2D]: + """Return a list of dummy 2D boxes. + + Returns: + List of 2D boxes. 
+ """ + return [ + Box2D( + unix_time=100, + frame_id="camera", + semantic_label=SemanticLabel(label=LabelID.CAR, original="car"), + roi=(100, 100, 50, 50), + confidence=1.0, + uuid="car2d_1", + ), + Box2D( + unix_time=100, + frame_id="camera", + semantic_label=SemanticLabel(label=LabelID.BICYCLE, original="bicycle"), + roi=(50, 50, 10, 10), + confidence=1.0, + uuid="bicycle2d_1", + ), + Box2D( + unix_time=100, + frame_id="camera", + semantic_label=SemanticLabel(label=LabelID.PEDESTRIAN, original="pedestrian"), + roi=(150, 150, 20, 20), + confidence=1.0, + uuid="pedestrian2d_1", + ), + ] diff --git a/tests/dataclass/test_box.py b/tests/dataclass/test_box.py new file mode 100644 index 0000000..bdcc1c3 --- /dev/null +++ b/tests/dataclass/test_box.py @@ -0,0 +1,36 @@ +import numpy as np + + +def test_box3d(dummy_box3d) -> None: + """Test `Box3D` class.""" + # test properties + assert np.allclose(dummy_box3d.size, (1.0, 1.0, 1.0)) + assert dummy_box3d.area == 1.0 + assert dummy_box3d.volume == 1.0 + + assert np.allclose( + dummy_box3d.corners(box_scale=1.0), + np.array( + [ + [0.5, 0.5, 1.5], + [0.5, 1.5, 1.5], + [0.5, 1.5, 0.5], + [0.5, 0.5, 0.5], + [1.5, 0.5, 1.5], + [1.5, 1.5, 1.5], + [1.5, 1.5, 0.5], + [1.5, 0.5, 0.5], + ] + ), + ) + + +def test_box2d(dummy_box2d) -> None: + """Test `Box2D` class.""" + # test properties + assert dummy_box2d.offset == (100, 100) + assert dummy_box2d.size == (50, 50) + assert dummy_box2d.width == 50 + assert dummy_box2d.height == 50 + assert dummy_box2d.center == (125, 125) + assert dummy_box2d.area == 2500 diff --git a/tests/dataclass/test_label.py b/tests/dataclass/test_label.py new file mode 100644 index 0000000..3bbdd47 --- /dev/null +++ b/tests/dataclass/test_label.py @@ -0,0 +1,102 @@ +import pytest + +from t4_devkit.dataclass.label import LabelID, convert_label + + +@pytest.mark.parametrize( + ("labels", "expect"), + [ + # === object === + # car + ( + ( + "car", + "vehicle.car", + "vehicle.construction", + "vehicle.emergency 
# Pairs of (original label name(s), expected label ID) for `test_convert_label`.
_LABEL_CASES = [
    # === object ===
    # car
    (
        (
            "car",
            "vehicle.car",
            "vehicle.construction",
            "vehicle.emergency (ambulance & police)",
            "vehicle.police",
        ),
        LabelID.CAR,
    ),
    # truck
    (("truck", "vehicle.truck", "trailer", "vehicle.trailer"), LabelID.TRUCK),
    # bus
    (("bus", "vehicle.bus", "vehicle.bus (bendy & rigid)"), LabelID.BUS),
    # bicycle
    (("bicycle", "vehicle.bicycle"), LabelID.BICYCLE),
    # motorbike
    (
        ("motorbike", "vehicle.motorbike", "motorcycle", "vehicle.motorcycle"),
        LabelID.MOTORBIKE,
    ),
    # pedestrian
    (
        (
            "pedestrian",
            "pedestrian.child",
            "pedestrian.personal_mobility",
            "pedestrian.police_officer",
            "pedestrian.stroller",
            "pedestrian.wheelchair",
            "construction_worker",
        ),
        LabelID.PEDESTRIAN,
    ),
    # animal
    ("animal", LabelID.ANIMAL),
    # unknown
    (
        (
            "movable_object.barrier",
            "movable_object.debris",
            "movable_object.pushable_pullable",
            "movable_object.trafficcone",
            "movable_object.traffic_cone",
            "static_object.bicycle_lack",
            "static_object.bollard",
            "forklift",
        ),
        LabelID.UNKNOWN,
    ),
    # === traffic light ===
    # GREEN
    (("green", "crosswalk_green"), LabelID.GREEN),
    ("green_straight", LabelID.GREEN_STRAIGHT),
    ("green_left", LabelID.GREEN_LEFT),
    ("green_right", LabelID.GREEN_RIGHT),
    # YELLOW
    ("yellow", LabelID.YELLOW),
    ("yellow_straight", LabelID.YELLOW_STRAIGHT),
    ("yellow_left", LabelID.YELLOW_LEFT),
    ("yellow_right", LabelID.YELLOW_RIGHT),
    ("yellow_straight_left", LabelID.YELLOW_STRAIGHT_LEFT),
    ("yellow_straight_right", LabelID.YELLOW_STRAIGHT_RIGHT),
    ("yellow_straight_left_right", LabelID.YELLOW_STRAIGHT_LEFT_RIGHT),
    # RED
    (("red", "crosswalk_red"), LabelID.RED),
    ("red_straight", LabelID.RED_STRAIGHT),
    ("red_left", LabelID.RED_LEFT),
    ("red_right", LabelID.RED_RIGHT),
    ("red_straight_left", LabelID.RED_STRAIGHT_LEFT),
    ("red_straight_right", LabelID.RED_STRAIGHT_RIGHT),
    ("red_straight_left_right", LabelID.RED_STRAIGHT_LEFT_RIGHT),
    (
        ("red_straight_left_diagonal", "red_straight_leftdiagonal"),
        LabelID.RED_LEFT_DIAGONAL,
    ),
    (
        ("red_straight_right_diagonal", "red_straight_rightdiagonal"),
        LabelID.RED_RIGHT_DIAGONAL,
    ),
    # unknown traffic light
    (("unknown", "crosswalk_unknown"), LabelID.UNKNOWN),
]


@pytest.mark.parametrize(("labels", "expect"), _LABEL_CASES)
def test_convert_label(labels: str | tuple[str, ...], expect: LabelID) -> None:
    """Check that every original name is converted to the expected label ID."""
    # A single name is handled in the same way as a group of names.
    candidates = [labels] if isinstance(labels, str) else labels

    for original in candidates:
        converted = convert_label(original)
        assert converted.label == expect
        assert converted.original == original


def test_roi() -> None:
    """Check the properties of `Roi`."""
    # list item is converted to tuple internally
    region = Roi(roi=[10, 20, 30, 40])

    assert region.offset == (10, 20)
    assert region.size == (30, 40)
    assert region.width == 30
    assert region.height == 40
    assert region.center == (25, 40)
    assert region.area == 1200


def test_shape_type() -> None:
    """Check `ShapeType.from_name()` for lower case, upper case and unknown names."""
    expectations = {
        # test lower case
        "bounding_box": ShapeType.BOUNDING_BOX,
        "polygon": ShapeType.POLYGON,
        # test upper case
        "BOUNDING_BOX": ShapeType.BOUNDING_BOX,
        "POLYGON": ShapeType.POLYGON,
    }
    for name, expected in expectations.items():
        assert ShapeType.from_name(name) == expected

    # test exception
    with pytest.raises(AssertionError):
        ShapeType.from_name("FOO")
Trajectory, to_trajectories + + +def test_trajectory() -> None: + """Test `Trajectory` class including its initialization and methods.""" + # list item is converted to NDArray internally + trajectory = Trajectory( + waypoints=[[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]], + confidence=1.0, + ) + + assert trajectory.confidence == 1.0 + + # test __len__() + assert len(trajectory) == 2 + + # test __getitem__() + assert np.allclose(trajectory[0], [1.0, 1.0, 1.0]) + assert np.allclose(trajectory[1], [2.0, 2.0, 2.0]) + + # test __iter__() + for point in trajectory: + assert isinstance(point, np.ndarray) + assert point.shape == (3,) + + # test shape property + assert trajectory.shape == (2, 3) + + +def test_to_trajectories() -> None: + """Test `to_trajectories` function including its valid and invalid cases.""" + # valid case + trajectories = to_trajectories( + waypoints=[ + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], # mode0 + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], # mode1 + ], + confidences=[ + 1.0, # mode0 + 2.0, # mode1 + ], + ) + assert len(trajectories) == 2 + + # invalid case: different element length + with pytest.raises(ValueError): + _ = to_trajectories( + waypoints=[ + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], + ], + confidences=[1.0], + )