Skip to content

Commit

Permalink
refactor: replace dataclasses to attrs
Browse files Browse the repository at this point in the history
Signed-off-by: ktro2828 <[email protected]>
  • Loading branch information
ktro2828 committed Nov 11, 2024
1 parent 042d7bf commit eda726a
Show file tree
Hide file tree
Showing 28 changed files with 242 additions and 452 deletions.
52 changes: 52 additions & 0 deletions t4_devkit/common/converter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from __future__ import annotations

from typing import TYPE_CHECKING, overload

import numpy as np
from pyquaternion import Quaternion

if TYPE_CHECKING:
from t4_devkit.typing import ArrayLike, NDArray

__all__ = ("as_array", "as_quaternion")


@overload
def as_array(value: ArrayLike) -> NDArray:
"""Covert array like object to numpy array."""
...


@overload
def as_array(value: None) -> None:
"""Do nothing if the input value is None."""
...


def as_array(value: ArrayLike | None) -> NDArray | None:
"""Convert input array to `NDArray`.
Note that, it returns `None` if the input is `None`.
Args:
value (ArrayLike | None): Array or None.
Returns:
NDArray | None: Numpy array or None.
"""
return np.asarray(value) if value is not None else None


def as_quaternion(value: ArrayLike | NDArray) -> Quaternion:
"""Convert input rotation like array to `Quaternion`.
Args:
value (ArrayLike | NDArray): Rotation matrix or quaternion.
Returns:
Quaternion: Converted instance.
"""
return (
Quaternion(matrix=value)
if isinstance(value, np.ndarray) and value.ndim == 2
else Quaternion(value)
)
35 changes: 11 additions & 24 deletions t4_devkit/dataclass/box.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar

import numpy as np
from pyquaternion import Quaternion
from attrs import define, field
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_array, as_quaternion

from .roi import Roi
from .trajectory import to_trajectories

Expand Down Expand Up @@ -57,7 +58,7 @@ def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None:
return np.linalg.norm(position)


@dataclass(eq=False)
@define(eq=False)
class BaseBox:
"""Abstract base class for box objects."""

Expand All @@ -72,7 +73,7 @@ class BaseBox:
# >>> e.g.) box.as_state() -> BoxState


@dataclass(eq=False)
@define(eq=False)
class Box3D(BaseBox):
"""A class to represent 3D box.
Expand Down Expand Up @@ -109,25 +110,15 @@ class Box3D(BaseBox):
... )
"""

position: TranslationType
rotation: RotationType
position: TranslationType = field(converter=as_array)
rotation: RotationType = field(converter=as_quaternion)
shape: Shape
velocity: VelocityType | None = field(default=None)
velocity: VelocityType | None = field(default=None, converter=as_array)
num_points: int | None = field(default=None)

# additional attributes: set by `with_**`
future: list[Trajectory] | None = field(default=None, init=False)

def __post_init__(self) -> None:
if not isinstance(self.position, np.ndarray):
self.position = np.array(self.position)

if not isinstance(self.rotation, Quaternion):
self.rotation = Quaternion(self.rotation)

if self.velocity is not None and not isinstance(self.velocity, np.ndarray):
self.velocity = np.array(self.velocity)

def with_future(
self,
waypoints: list[TrajectoryType],
Expand Down Expand Up @@ -195,7 +186,7 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64:
return np.dot(self.rotation.rotation_matrix, corners).T + self.position


@dataclass(eq=False)
@define(eq=False)
class Box2D(BaseBox):
"""A class to represent 2D box.
Expand All @@ -222,15 +213,11 @@ class Box2D(BaseBox):
>>> box2d = box2d.with_position(position=(1.0, 1.0, 1.0))
"""

roi: Roi | None = field(default=None)
roi: Roi | None = field(default=None, converter=lambda x: None if x is None else Roi(x))

# additional attributes: set by `with_**`
position: TranslationType | None = field(default=None, init=False)

def __post_init__(self) -> None:
if self.roi is not None and not isinstance(self.roi, Roi):
self.roi = Roi(self.roi)

def with_position(self, position: TranslationType) -> Self:
"""Return a self instance setting `position` attribute.
Expand All @@ -240,7 +227,7 @@ def with_position(self, position: TranslationType) -> Self:
Returns:
Self instance after setting `position`.
"""
self.position = np.array(position) if not isinstance(position, np.ndarray) else position
self.position = as_array(position)
return self

def __eq__(self, other: Box2D | None) -> bool:
Expand Down
6 changes: 3 additions & 3 deletions t4_devkit/dataclass/label.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from enum import Enum, auto, unique

from attrs import define, field
from typing_extensions import Self

__all__ = ["LabelID", "SemanticLabel", "convert_label"]
Expand Down Expand Up @@ -58,7 +58,7 @@ def __eq__(self, other: str | LabelID) -> bool:
return self.name == other.upper() if isinstance(other, str) else self.name == other.name


@dataclass(frozen=True, eq=False)
@define(frozen=True, eq=False)
class SemanticLabel:
"""A dataclass to represent semantic labels.
Expand All @@ -70,7 +70,7 @@ class SemanticLabel:

label: LabelID
original: str | None = field(default=None)
attributes: list[str] = field(default_factory=list)
attributes: list[str] = field(factory=list)

def __eq__(self, other: str | SemanticLabel) -> bool:
return self.label == other if isinstance(other, str) else self.label == other.label
Expand Down
24 changes: 15 additions & 9 deletions t4_devkit/dataclass/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

import struct
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, TypeVar

import numpy as np
from attrs import define, field

from t4_devkit.common.converter import as_array

if TYPE_CHECKING:
from typing_extensions import Self
Expand All @@ -21,14 +23,18 @@
]


@dataclass
@define
class PointCloud:
"""Abstract base dataclass for pointcloud data."""

points: NDArrayFloat
points: NDArrayFloat = field(converter=as_array)

def __post_init__(self) -> None:
assert self.points.shape[0] == self.num_dims()
@points.validator
def check_dims(self, attribute, value) -> None:
if value.shape[0] != self.num_dims():
raise ValueError(
f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}"
)

@staticmethod
@abstractmethod
Expand Down Expand Up @@ -74,7 +80,7 @@ def transform(self, matrix: NDArrayFloat) -> None:
)[:3, :]


@dataclass
@define
class LidarPointCloud(PointCloud):
"""A dataclass to represent lidar pointcloud."""

Expand All @@ -91,7 +97,7 @@ def from_file(cls, filepath: str) -> Self:
return cls(points.T)


@dataclass
@define
class RadarPointCloud(PointCloud):
# class variables
invalid_states: ClassVar[list[int]] = [0]
Expand Down Expand Up @@ -188,9 +194,9 @@ def from_file(
return cls(points)


@dataclass
@define
class SegmentationPointCloud(PointCloud):
labels: NDArrayU8
labels: NDArrayU8 = field(converter=as_array)

@staticmethod
def num_dims() -> int:
Expand Down
10 changes: 4 additions & 6 deletions t4_devkit/dataclass/roi.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,24 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from attrs import define, field

if TYPE_CHECKING:
from t4_devkit.typing import RoiType

__all__ = ["Roi"]


@dataclass
@define
class Roi:
roi: RoiType
roi: RoiType = field(converter=tuple)

def __post_init__(self) -> None:
assert len(self.roi) == 4, (
"Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}."
)

if not isinstance(self.roi, tuple):
self.roi = tuple(self.roi)

@property
def offset(self) -> tuple[int, int]:
return self.roi[:2]
Expand Down
13 changes: 6 additions & 7 deletions t4_devkit/dataclass/shape.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum, auto, unique
from typing import TYPE_CHECKING

import numpy as np
from attrs import define, field
from shapely.geometry import Polygon
from typing_extensions import Self

from t4_devkit.common.converter import as_array

if TYPE_CHECKING:
from t4_devkit.typing import NDArrayF64, SizeType

Expand Down Expand Up @@ -35,7 +37,7 @@ def from_name(cls, name: str) -> Self:
return cls.__members__[name]


@dataclass
@define
class Shape:
"""A dataclass to represent the 3D box shape.
Expand All @@ -47,13 +49,10 @@ class Shape:
"""

shape_type: ShapeType
size: SizeType
size: SizeType = field(converter=as_array)
footprint: Polygon = field(default=None)

def __post_init__(self) -> None:
if not isinstance(self.size, np.ndarray):
self.size = np.array(self.size)

def __attrs_post_init__(self) -> None:
if self.shape_type == ShapeType.POLYGON and self.footprint is None:
raise ValueError("`footprint` must be specified for `POLYGON`.")

Expand Down
17 changes: 8 additions & 9 deletions t4_devkit/dataclass/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generator

import numpy as np
from attrs import define, field

from t4_devkit.common.converter import as_array

if TYPE_CHECKING:
from t4_devkit.typing import TrajectoryType, TranslationType

__all__ = ["Trajectory", "to_trajectories"]


@dataclass
@define
class Trajectory:
"""A dataclass to represent trajectory.
Expand Down Expand Up @@ -41,14 +42,12 @@ class Trajectory:
[2. 2. 2.]
"""

waypoints: TrajectoryType
waypoints: TrajectoryType = field(converter=as_array)
confidence: float = field(default=1.0)

def __post_init__(self) -> None:
if not isinstance(self.waypoints, np.ndarray):
self.waypoints = np.array(self.waypoints)

assert self.waypoints.shape[1] == 3
@waypoints.validator
def check_dims(self, attribute, value) -> None:
assert value.shape[1] == 3

def __len__(self) -> int:
return len(self.waypoints)
Expand Down
Loading

0 comments on commit eda726a

Please sign in to comment.