Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: replace dataclasses to attrs #27

Merged
merged 1 commit into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions t4_devkit/common/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from pyquaternion import Quaternion

if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray

__all__ = ["as_quaternion"]


def as_quaternion(value: ArrayLike | NDArray) -> Quaternion:
"""Convert input rotation like array to `Quaternion`.

Args:
value (ArrayLike | NDArray): Rotation matrix or quaternion.

Returns:
Quaternion: Converted instance.
"""
return (
Quaternion(matrix=value)
if isinstance(value, np.ndarray) and value.ndim == 2
else Quaternion(value)
)
36 changes: 12 additions & 24 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar

import numpy as np
from pyquaternion import Quaternion
from attrs import define, field
from attrs.converters import optional
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_quaternion

from .roi import Roi
from .trajectory import to_trajectories

Expand Down Expand Up @@ -57,7 +59,7 @@ def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None:
return np.linalg.norm(position)


@dataclass(eq=False)
@define(eq=False)
class BaseBox:
"""Abstract base class for box objects."""

Expand All @@ -72,7 +74,7 @@ class BaseBox:
# >>> e.g.) box.as_state() -> BoxState


@dataclass(eq=False)
@define(eq=False)
class Box3D(BaseBox):
"""A class to represent 3D box.

Expand Down Expand Up @@ -109,25 +111,15 @@ class Box3D(BaseBox):
... )
"""

position: TranslationType
rotation: RotationType
position: TranslationType = field(converter=np.asarray)
rotation: RotationType = field(converter=as_quaternion)
shape: Shape
velocity: VelocityType | None = field(default=None)
velocity: VelocityType | None = field(default=None, converter=optional(np.asarray))
num_points: int | None = field(default=None)

# additional attributes: set by `with_**`
future: list[Trajectory] | None = field(default=None, init=False)

def __post_init__(self) -> None:
if not isinstance(self.position, np.ndarray):
self.position = np.array(self.position)

if not isinstance(self.rotation, Quaternion):
self.rotation = Quaternion(self.rotation)

if self.velocity is not None and not isinstance(self.velocity, np.ndarray):
self.velocity = np.array(self.velocity)

def with_future(
self,
waypoints: list[TrajectoryType],
Expand Down Expand Up @@ -195,7 +187,7 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64:
return np.dot(self.rotation.rotation_matrix, corners).T + self.position


@dataclass(eq=False)
@define(eq=False)
class Box2D(BaseBox):
"""A class to represent 2D box.

Expand All @@ -222,15 +214,11 @@ class Box2D(BaseBox):
>>> box2d = box2d.with_position(position=(1.0, 1.0, 1.0))
"""

roi: Roi | None = field(default=None)
roi: Roi | None = field(default=None, converter=lambda x: None if x is None else Roi(x))

# additional attributes: set by `with_**`
position: TranslationType | None = field(default=None, init=False)

def __post_init__(self) -> None:
if self.roi is not None and not isinstance(self.roi, Roi):
self.roi = Roi(self.roi)

def with_position(self, position: TranslationType) -> Self:
"""Return a self instance setting `position` attribute.

Expand All @@ -240,7 +228,7 @@ def with_position(self, position: TranslationType) -> Self:
Returns:
Self instance after setting `position`.
"""
self.position = np.array(position) if not isinstance(position, np.ndarray) else position
self.position = np.asarray(position)
return self

def __eq__(self, other: Box2D | None) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from dataclasses import dataclass, field
from attrs import define, field

__all__ = ["SemanticLabel"]


@dataclass(frozen=True, eq=False)
@define(frozen=True, eq=False)
class SemanticLabel:
"""A dataclass to represent semantic labels.

Expand All @@ -15,7 +15,7 @@ class SemanticLabel:
"""

name: str
attributes: list[str] = field(default_factory=list)
attributes: list[str] = field(factory=list)

def __eq__(self, other: str | SemanticLabel) -> bool:
return self.name == other if isinstance(other, str) else self.name == other.name
22 changes: 13 additions & 9 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

import struct
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, TypeVar

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -21,14 +21,18 @@
]


@dataclass
@define
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat
points: NDArrayFloat = field(converter=np.asarray)

def __post_init__(self) -> None:
assert self.points.shape[0] == self.num_dims()
@points.validator
def check_dims(self, attribute, value) -> None:
if value.shape[0] != self.num_dims():
raise ValueError(
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -74,7 +78,7 @@ def transform(self, matrix: NDArrayFloat) -> None:
)[:3, :]


@dataclass
@define
class LidarPointCloud(PointCloud):
"""A dataclass to represent lidar pointcloud."""

Expand All @@ -91,7 +95,7 @@ def from_file(cls, filepath: str) -> Self:
return cls(points.T)


@dataclass
@define
class RadarPointCloud(PointCloud):
# class variables
invalid_states: ClassVar[list[int]] = [0]
Expand Down Expand Up @@ -188,9 +192,9 @@ def from_file(
return cls(points)


@dataclass
@define
class SegmentationPointCloud(PointCloud):
labels: NDArrayU8
labels: NDArrayU8 = field(converter=lambda x: np.asarray(x, dtype=np.uint8))

@staticmethod
def num_dims() -> int:
Expand Down
10 changes: 4 additions & 6 deletions t4_devkit/dataclass/roi.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import RoiType

__all__ = ["Roi"]


@dataclass
@define
class Roi:
roi: RoiType
roi: RoiType = field(converter=tuple)

def __post_init__(self) -> None:
assert len(self.roi) == 4, (
"Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
)

if not isinstance(self.roi, tuple):
self.roi = tuple(self.roi)

@property
def offset(self) -> tuple[int, int]:
return self.roi[:2]
Expand Down
11 changes: 4 additions & 7 deletions t4_devkit/dataclass/shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto, unique
from typing import TYPE_CHECKING

import numpy as np
from attrs import define, field
from shapely.geometry import Polygon
from typing_extensions import Self

Expand Down Expand Up @@ -35,7 +35,7 @@ def from_name(cls, name: str) -> Self:
return cls.__members__[name]


@dataclass
@define
class Shape:
"""A dataclass to represent the 3D box shape.

Expand All @@ -47,13 +47,10 @@ class Shape:
"""

shape_type: ShapeType
size: SizeType
size: SizeType = field(converter=np.asarray)
footprint: Polygon = field(default=None)

def __post_init__(self) -> None:
if not isinstance(self.size, np.ndarray):
self.size = np.array(self.size)

def __attrs_post_init__(self) -> None:
if self.shape_type == ShapeType.POLYGON and self.footprint is None:
raise ValueError("`footprint` must be specified for `POLYGON`.")

Expand Down
14 changes: 6 additions & 8 deletions t4_devkit/dataclass/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generator

import numpy as np
from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import TrajectoryType, TranslationType

__all__ = ["Trajectory", "to_trajectories"]


@dataclass
@define
class Trajectory:
"""A dataclass to represent trajectory.

Expand Down Expand Up @@ -41,14 +41,12 @@ class Trajectory:
[2. 2. 2.]
"""

waypoints: TrajectoryType
waypoints: TrajectoryType = field(converter=np.asarray)
confidence: float = field(default=1.0)

def __post_init__(self) -> None:
if not isinstance(self.waypoints, np.ndarray):
self.waypoints = np.array(self.waypoints)

assert self.waypoints.shape[1] == 3
def __attrs_post_init__(self) -> None:
if self.waypoints.shape[1] != 3:
raise ValueError("Trajectory dimension must be 3.")

def __len__(self) -> int:
return len(self.waypoints)
Expand Down
Loading
Loading