From eda726a8f6cda1b3baf7c98535eb0af26c8e62b9 Mon Sep 17 00:00:00 2001 From: ktro2828 Date: Mon, 11 Nov 2024 23:30:07 +0900 Subject: [PATCH] refactor: replace `dataclasses` to `attrs` Signed-off-by: ktro2828 --- t4_devkit/common/converter.py | 52 +++++++++++++++++ t4_devkit/dataclass/box.py | 35 ++++------- t4_devkit/dataclass/label.py | 6 +- t4_devkit/dataclass/pointcloud.py | 24 +++++--- t4_devkit/dataclass/roi.py | 10 ++-- t4_devkit/dataclass/shape.py | 13 ++--- t4_devkit/dataclass/trajectory.py | 17 +++--- t4_devkit/dataclass/transform.py | 46 +++++---------- t4_devkit/filtering/parameter.py | 4 +- t4_devkit/schema/serialize.py | 2 +- t4_devkit/schema/tables/attribute.py | 16 ++--- t4_devkit/schema/tables/base.py | 19 +++--- t4_devkit/schema/tables/calibrated_sensor.py | 49 ++++------------ t4_devkit/schema/tables/category.py | 14 +---- t4_devkit/schema/tables/ego_pose.py | 29 +++------- t4_devkit/schema/tables/instance.py | 16 ++--- t4_devkit/schema/tables/keypoint.py | 35 +++-------- t4_devkit/schema/tables/log.py | 18 +++--- t4_devkit/schema/tables/map.py | 16 ++--- t4_devkit/schema/tables/object_ann.py | 25 ++++---- t4_devkit/schema/tables/registry.py | 3 +- t4_devkit/schema/tables/sample.py | 22 +++---- t4_devkit/schema/tables/sample_annotation.py | 60 ++++--------------- t4_devkit/schema/tables/sample_data.py | 61 +++----------------- t4_devkit/schema/tables/scene.py | 16 ++--- t4_devkit/schema/tables/sensor.py | 33 ++--------- t4_devkit/schema/tables/surface_ann.py | 18 ++---- t4_devkit/schema/tables/visibility.py | 35 ++++------- 28 files changed, 242 insertions(+), 452 deletions(-) create mode 100644 t4_devkit/common/converter.py diff --git a/t4_devkit/common/converter.py b/t4_devkit/common/converter.py new file mode 100644 index 0000000..5f7cdb6 --- /dev/null +++ b/t4_devkit/common/converter.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, overload + +import numpy as np +from pyquaternion import Quaternion + +if TYPE_CHECKING: + from t4_devkit.typing import ArrayLike, NDArray + +__all__ = ("as_array", "as_quaternion") + + +@overload +def as_array(value: ArrayLike) -> NDArray: + """Covert array like object to numpy array.""" + ... + + +@overload +def as_array(value: None) -> None: + """Do nothing if the input value is None.""" + ... + + +def as_array(value: ArrayLike | None) -> NDArray | None: + """Convert input array to `NDArray`. + Note that, it returns `None` if the input is `None`. + + Args: + value (ArrayLike | None): Array or None. + + Returns: + NDArray | None: Numpy array or None. + """ + return np.asarray(value) if value is not None else None + + +def as_quaternion(value: ArrayLike | NDArray) -> Quaternion: + """Convert input rotation like array to `Quaternion`. + + Args: + value (ArrayLike | NDArray): Rotation matrix or quaternion. + + Returns: + Quaternion: Converted instance. + """ + return ( + Quaternion(matrix=value) + if isinstance(value, np.ndarray) and value.ndim == 2 + else Quaternion(value) + ) diff --git a/t4_devkit/dataclass/box.py b/t4_devkit/dataclass/box.py index f7b6729..5d5ea01 100644 --- a/t4_devkit/dataclass/box.py +++ b/t4_devkit/dataclass/box.py @@ -1,13 +1,14 @@ from __future__ import annotations -from dataclasses import dataclass, field from typing import TYPE_CHECKING, TypeVar import numpy as np -from pyquaternion import Quaternion +from attrs import define, field from shapely.geometry import Polygon from typing_extensions import Self +from t4_devkit.common.converter import as_array, as_quaternion + from .roi import Roi from .trajectory import to_trajectories @@ -57,7 +58,7 @@ def distance_box(box: BoxType, tf_matrix: HomogeneousMatrix) -> float | None: return np.linalg.norm(position) -@dataclass(eq=False) +@define(eq=False) class BaseBox: """Abstract base class for box objects.""" @@ -72,7 +73,7 @@ class BaseBox: # >>> e.g.) box.as_state() -> BoxState -@dataclass(eq=False) +@define(eq=False) class Box3D(BaseBox): """A class to represent 3D box. @@ -109,25 +110,15 @@ class Box3D(BaseBox): ... ) """ - position: TranslationType - rotation: RotationType + position: TranslationType = field(converter=as_array) + rotation: RotationType = field(converter=as_quaternion) shape: Shape - velocity: VelocityType | None = field(default=None) + velocity: VelocityType | None = field(default=None, converter=as_array) num_points: int | None = field(default=None) # additional attributes: set by `with_**` future: list[Trajectory] | None = field(default=None, init=False) - def __post_init__(self) -> None: - if not isinstance(self.position, np.ndarray): - self.position = np.array(self.position) - - if not isinstance(self.rotation, Quaternion): - self.rotation = Quaternion(self.rotation) - - if self.velocity is not None and not isinstance(self.velocity, np.ndarray): - self.velocity = np.array(self.velocity) - def with_future( self, waypoints: list[TrajectoryType], @@ -195,7 +186,7 @@ def corners(self, box_scale: float = 1.0) -> NDArrayF64: return np.dot(self.rotation.rotation_matrix, corners).T + self.position -@dataclass(eq=False) +@define(eq=False) class Box2D(BaseBox): """A class to represent 2D box. @@ -222,15 +213,11 @@ class Box2D(BaseBox): >>> box2d = box2d.with_position(position=(1.0, 1.0, 1.0)) """ - roi: Roi | None = field(default=None) + roi: Roi | None = field(default=None, converter=lambda x: None if x is None else Roi(x)) # additional attributes: set by `with_**` position: TranslationType | None = field(default=None, init=False) - def __post_init__(self) -> None: - if self.roi is not None and not isinstance(self.roi, Roi): - self.roi = Roi(self.roi) - def with_position(self, position: TranslationType) -> Self: """Return a self instance setting `position` attribute. @@ -240,7 +227,7 @@ def with_position(self, position: TranslationType) -> Self: Returns: Self instance after setting `position`. """ - self.position = np.array(position) if not isinstance(position, np.ndarray) else position + self.position = as_array(position) return self def __eq__(self, other: Box2D | None) -> bool: diff --git a/t4_devkit/dataclass/label.py b/t4_devkit/dataclass/label.py index ab5c2e2..dbcbe31 100644 --- a/t4_devkit/dataclass/label.py +++ b/t4_devkit/dataclass/label.py @@ -1,9 +1,9 @@ from __future__ import annotations import warnings -from dataclasses import dataclass, field from enum import Enum, auto, unique +from attrs import define, field from typing_extensions import Self __all__ = ["LabelID", "SemanticLabel", "convert_label"] @@ -58,7 +58,7 @@ def __eq__(self, other: str | LabelID) -> bool: return self.name == other.upper() if isinstance(other, str) else self.name == other.name -@dataclass(frozen=True, eq=False) +@define(frozen=True, eq=False) class SemanticLabel: """A dataclass to represent semantic labels. @@ -70,7 +70,7 @@ class SemanticLabel: label: LabelID original: str | None = field(default=None) - attributes: list[str] = field(default_factory=list) + attributes: list[str] = field(factory=list) def __eq__(self, other: str | SemanticLabel) -> bool: return self.label == other if isinstance(other, str) else self.label == other.label diff --git a/t4_devkit/dataclass/pointcloud.py b/t4_devkit/dataclass/pointcloud.py index 97a5fea..06147c3 100644 --- a/t4_devkit/dataclass/pointcloud.py +++ b/t4_devkit/dataclass/pointcloud.py @@ -2,10 +2,12 @@ import struct from abc import abstractmethod -from dataclasses import dataclass from typing import TYPE_CHECKING, ClassVar, TypeVar import numpy as np +from attrs import define, field + +from t4_devkit.common.converter import as_array if TYPE_CHECKING: from typing_extensions import Self @@ -21,14 +23,18 @@ ] -@dataclass +@define class PointCloud: """Abstract base dataclass for pointcloud data.""" - points: NDArrayFloat + points: NDArrayFloat = field(converter=as_array) - def __post_init__(self) -> None: - assert self.points.shape[0] == self.num_dims() + @points.validator + def check_dims(self, attribute, value) -> None: + if value.shape[0] != self.num_dims(): + raise ValueError( + f"Expected point dimension is {self.num_dims()}, but got {value.shape[0]}" + ) @staticmethod @abstractmethod @@ -74,7 +80,7 @@ def transform(self, matrix: NDArrayFloat) -> None: )[:3, :] -@dataclass +@define class LidarPointCloud(PointCloud): """A dataclass to represent lidar pointcloud.""" @@ -91,7 +97,7 @@ def from_file(cls, filepath: str) -> Self: return cls(points.T) -@dataclass +@define class RadarPointCloud(PointCloud): # class variables invalid_states: ClassVar[list[int]] = [0] @@ -188,9 +194,9 @@ def from_file( return cls(points) -@dataclass +@define class SegmentationPointCloud(PointCloud): - labels: NDArrayU8 + labels: NDArrayU8 = field(converter=as_array) @staticmethod def num_dims() -> int: diff --git a/t4_devkit/dataclass/roi.py b/t4_devkit/dataclass/roi.py index f9e4a5f..597e17b 100644 --- a/t4_devkit/dataclass/roi.py +++ b/t4_devkit/dataclass/roi.py @@ -1,26 +1,24 @@ from __future__ import annotations -from dataclasses import dataclass from typing import TYPE_CHECKING +from attrs import define, field + if TYPE_CHECKING: from t4_devkit.typing import RoiType __all__ = ["Roi"] -@dataclass +@define class Roi: - roi: RoiType + roi: RoiType = field(converter=tuple) def __post_init__(self) -> None: assert len(self.roi) == 4, ( "Expected roi is (x, y, width, height), " f"but got length with {len(self.roi)}." ) - if not isinstance(self.roi, tuple): - self.roi = tuple(self.roi) - @property def offset(self) -> tuple[int, int]: return self.roi[:2] diff --git a/t4_devkit/dataclass/shape.py b/t4_devkit/dataclass/shape.py index 9fe9e08..2823584 100644 --- a/t4_devkit/dataclass/shape.py +++ b/t4_devkit/dataclass/shape.py @@ -1,13 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass, field from enum import Enum, auto, unique from typing import TYPE_CHECKING import numpy as np +from attrs import define, field from shapely.geometry import Polygon from typing_extensions import Self +from t4_devkit.common.converter import as_array + if TYPE_CHECKING: from t4_devkit.typing import NDArrayF64, SizeType @@ -35,7 +37,7 @@ def from_name(cls, name: str) -> Self: return cls.__members__[name] -@dataclass +@define class Shape: """A dataclass to represent the 3D box shape. @@ -47,13 +49,10 @@ class Shape: """ shape_type: ShapeType - size: SizeType + size: SizeType = field(converter=as_array) footprint: Polygon = field(default=None) - def __post_init__(self) -> None: - if not isinstance(self.size, np.ndarray): - self.size = np.array(self.size) - + def __attrs_post_init__(self) -> None: if self.shape_type == ShapeType.POLYGON and self.footprint is None: raise ValueError("`footprint` must be specified for `POLYGON`.") diff --git a/t4_devkit/dataclass/trajectory.py b/t4_devkit/dataclass/trajectory.py index ff013bc..ce27fa6 100644 --- a/t4_devkit/dataclass/trajectory.py +++ b/t4_devkit/dataclass/trajectory.py @@ -1,9 +1,10 @@ from __future__ import annotations -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Generator -import numpy as np +from attrs import define, field + +from t4_devkit.common.converter import as_array if TYPE_CHECKING: from t4_devkit.typing import TrajectoryType, TranslationType @@ -11,7 +12,7 @@ __all__ = ["Trajectory", "to_trajectories"] -@dataclass +@define class Trajectory: """A dataclass to represent trajectory. @@ -41,14 +42,12 @@ class Trajectory: [2. 2. 2.] """ - waypoints: TrajectoryType + waypoints: TrajectoryType = field(converter=as_array) confidence: float = field(default=1.0) - def __post_init__(self) -> None: - if not isinstance(self.waypoints, np.ndarray): - self.waypoints = np.array(self.waypoints) - - assert self.waypoints.shape[1] == 3 + @waypoints.validator + def check_dims(self, attribute, value) -> None: + assert value.shape[1] == 3 def __len__(self) -> int: return len(self.waypoints) diff --git a/t4_devkit/dataclass/transform.py b/t4_devkit/dataclass/transform.py index 8b5863c..e1bd4fc 100644 --- a/t4_devkit/dataclass/transform.py +++ b/t4_devkit/dataclass/transform.py @@ -1,17 +1,18 @@ from __future__ import annotations from copy import deepcopy -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, overload import numpy as np +from attrs import define, field from pyquaternion import Quaternion from typing_extensions import Self +from t4_devkit.common.converter import as_array, as_quaternion from t4_devkit.typing import NDArray, RotationType if TYPE_CHECKING: - from t4_devkit.typing import ArrayLike + from t4_devkit.typing import ArrayLike, TranslationType __all__ = [ "TransformBuffer", @@ -22,9 +23,9 @@ ] -@dataclass +@define class TransformBuffer: - buffer: dict[tuple[str, str], HomogeneousMatrix] = field(default_factory=dict, init=False) + buffer: dict[tuple[str, str], HomogeneousMatrix] = field(factory=dict, init=False) def set_transform(self, matrix: HomogeneousMatrix) -> None: """Set transform matrix to the buffer. @@ -59,35 +60,16 @@ def do_transform(self, src: str, dst: str, *args, **kwargs) -> TransformItemLike return tf_matrix.transform(*args, **kwargs) if tf_matrix is not None else None -@dataclass +@define class HomogeneousMatrix: - def __init__( - self, - position: ArrayLike, - rotation: ArrayLike | RotationType, - src: str, - dst: str, - ) -> None: - """Construct a new object. - - Args: - position (ArrayLike): 3D position. - rotation (ArrayLike | RotationType): 3x3 rotation matrix or quaternion. - src (str): Source frame ID. - dst (str): Destination frame ID. - """ - self.position = position if isinstance(position, np.ndarray) else np.array(position) - - if isinstance(rotation, np.ndarray) and rotation.ndim == 2: - rotation = Quaternion(matrix=rotation) - elif not isinstance(rotation, Quaternion): - rotation = Quaternion(rotation) - self.rotation = rotation - - self.src = src - self.dst = dst - - self.matrix = _generate_homogeneous_matrix(position, rotation) + position: TranslationType = field(converter=as_array) + rotation: Quaternion = field(converter=as_quaternion) + src: str + dst: str + matrix: NDArray = field(init=False) + + def __attrs_post_init__(self) -> None: + self.matrix = _generate_homogeneous_matrix(self.position, self.rotation) @classmethod def as_identity(cls, frame_id: str) -> Self: diff --git a/t4_devkit/filtering/parameter.py b/t4_devkit/filtering/parameter.py index 922aebb..4a6636b 100644 --- a/t4_devkit/filtering/parameter.py +++ b/t4_devkit/filtering/parameter.py @@ -1,15 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass, field from typing import TYPE_CHECKING, Sequence import numpy as np +from attrs import define, field if TYPE_CHECKING: from t4_devkit.dataclass import SemanticLabel -@dataclass +@define class FilterParams: """A dataclass to represent filtering parameters. diff --git a/t4_devkit/schema/serialize.py b/t4_devkit/schema/serialize.py index 46303fa..e8853af 100644 --- a/t4_devkit/schema/serialize.py +++ b/t4_devkit/schema/serialize.py @@ -1,11 +1,11 @@ from __future__ import annotations -from dataclasses import asdict from enum import Enum from functools import partial from typing import TYPE_CHECKING, Any, Sequence import numpy as np +from attrs import asdict from pyquaternion import Quaternion if TYPE_CHECKING: diff --git a/t4_devkit/schema/tables/attribute.py b/t4_devkit/schema/tables/attribute.py index 22d79cd..d0368b9 100644 --- a/t4_devkit/schema/tables/attribute.py +++ b/t4_devkit/schema/tables/attribute.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Self +from attrs import define +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName -__all__ = ("Attribute",) +__all__ = ["Attribute"] -@dataclass +@define @SCHEMAS.register(SchemaName.ATTRIBUTE) class Attribute(SchemaBase): """A dataclass to represent schema table of `attribute.json`. @@ -23,10 +20,5 @@ class Attribute(SchemaBase): description (str): Attribute description. """ - token: str name: str description: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/base.py b/t4_devkit/schema/tables/base.py index 546f98c..013049a 100644 --- a/t4_devkit/schema/tables/base.py +++ b/t4_devkit/schema/tables/base.py @@ -1,14 +1,16 @@ -from abc import ABC, abstractmethod -from dataclasses import dataclass +from __future__ import annotations + +from abc import ABC from typing import Any, TypeVar +from attrs import define + from t4_devkit.common.io import load_json -from typing_extensions import Self -__all__ = ("SchemaBase", "SchemaTable") +__all__ = ["SchemaBase", "SchemaTable"] -@dataclass +@define class SchemaBase(ABC): """Abstract base dataclass of schema tables.""" @@ -24,7 +26,7 @@ def shortcuts() -> tuple[str, ...] | None: return None @classmethod - def from_json(cls, filepath: str) -> list[Self]: + def from_json(cls, filepath: str) -> list[SchemaTable]: """Construct dataclass from json file. Args: @@ -37,8 +39,7 @@ def from_json(cls, filepath: str) -> list[Self]: return [cls.from_dict(data) for data in records] @classmethod - @abstractmethod - def from_dict(cls, data: dict[str, Any]) -> Self: + def from_dict(cls, data: dict[str, Any]) -> SchemaTable: """Construct dataclass from dict. Args: @@ -47,7 +48,7 @@ def from_dict(cls, data: dict[str, Any]) -> Self: Returns: Instantiated schema dataclass. """ - ... + return cls(**data) SchemaTable = TypeVar("SchemaTable", bound=SchemaBase) diff --git a/t4_devkit/schema/tables/calibrated_sensor.py b/t4_devkit/schema/tables/calibrated_sensor.py index 6e73ddc..43f6de3 100644 --- a/t4_devkit/schema/tables/calibrated_sensor.py +++ b/t4_devkit/schema/tables/calibrated_sensor.py @@ -1,28 +1,22 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import numpy as np -from pyquaternion import Quaternion -from typing_extensions import Self +from attrs import define, field +from t4_devkit.common.converter import as_array, as_quaternion + +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName if TYPE_CHECKING: - from t4_devkit.typing import ( - CamDistortionType, - CamIntrinsicType, - RotationType, - TranslationType, - ) + from t4_devkit.typing import CamDistortionType, CamIntrinsicType, RotationType, TranslationType -__all__ = ("CalibratedSensor",) +__all__ = ["CalibratedSensor"] -@dataclass +@define @SCHEMAS.register(SchemaName.CALIBRATED_SENSOR) class CalibratedSensor(SchemaBase): """A dataclass to represent schema table of `calibrated_sensor.json`. @@ -36,27 +30,8 @@ class CalibratedSensor(SchemaBase): camera_distortion (CamDistortionType): Camera distortion array. Empty for sensors that are not cameras. """ - token: str sensor_token: str - translation: TranslationType - rotation: RotationType - camera_intrinsic: CamIntrinsicType - camera_distortion: CamDistortionType - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - sensor_token: str = data["sensor_token"] - translation = np.array(data["translation"]) - rotation = Quaternion(data["rotation"]) - camera_intrinsic = np.array(data["camera_intrinsic"]) - camera_distortion = np.array(data["camera_distortion"]) - - return cls( - token=token, - sensor_token=sensor_token, - translation=translation, - rotation=rotation, - camera_intrinsic=camera_intrinsic, - camera_distortion=camera_distortion, - ) + translation: TranslationType = field(converter=as_array) + rotation: RotationType = field(converter=as_quaternion) + camera_intrinsic: CamIntrinsicType = field(converter=as_array) + camera_distortion: CamDistortionType = field(converter=as_array) diff --git a/t4_devkit/schema/tables/category.py b/t4_devkit/schema/tables/category.py index 204013d..bfdca33 100644 --- a/t4_devkit/schema/tables/category.py +++ b/t4_devkit/schema/tables/category.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Self +from attrs import define +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName __all__ = ("Category",) -@dataclass +@define @SCHEMAS.register(SchemaName.CATEGORY) class Category(SchemaBase): """A dataclass to represent schema table of `category.json`. @@ -23,10 +20,5 @@ class Category(SchemaBase): description (str): Category description. """ - token: str name: str description: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/ego_pose.py b/t4_devkit/schema/tables/ego_pose.py index 6925378..8e85b87 100644 --- a/t4_devkit/schema/tables/ego_pose.py +++ b/t4_devkit/schema/tables/ego_pose.py @@ -1,23 +1,22 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import numpy as np -from pyquaternion import Quaternion -from typing_extensions import Self +from attrs import define, field +from t4_devkit.common.converter import as_array, as_quaternion + +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName if TYPE_CHECKING: from t4_devkit.typing import RotationType, TranslationType -__all__ = ("EgoPose",) +__all__ = ["EgoPose"] -@dataclass +@define @SCHEMAS.register(SchemaName.EGO_POSE) class EgoPose(SchemaBase): """A dataclass to represent schema table of `ego_pose.json`. @@ -29,16 +28,6 @@ class EgoPose(SchemaBase): timestamp (int): Unix time stamp. """ - token: str - translation: TranslationType - rotation: RotationType + translation: TranslationType = field(converter=as_array) + rotation: RotationType = field(converter=as_quaternion) timestamp: int - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - translation = np.array(data["translation"]) - rotation = Quaternion(data["rotation"]) - timestamp: int = data["timestamp"] - - return cls(token=token, translation=translation, rotation=rotation, timestamp=timestamp) diff --git a/t4_devkit/schema/tables/instance.py b/t4_devkit/schema/tables/instance.py index 739d3ea..a100767 100644 --- a/t4_devkit/schema/tables/instance.py +++ b/t4_devkit/schema/tables/instance.py @@ -1,16 +1,13 @@ -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Self +from attrs import define +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName -__all__ = ("Instance",) +__all__ = ["Instance"] -@dataclass +@define @SCHEMAS.register(SchemaName.INSTANCE) class Instance(SchemaBase): """A dataclass to represent schema table of `instance.json`. @@ -24,13 +21,8 @@ class Instance(SchemaBase): last_annotation_token (str): Foreign key pointing to the last annotation of this instance. """ - token: str category_token: str instance_name: str nbr_annotations: int first_annotation_token: str last_annotation_token: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/keypoint.py b/t4_devkit/schema/tables/keypoint.py index c63580e..6a92b60 100644 --- a/t4_devkit/schema/tables/keypoint.py +++ b/t4_devkit/schema/tables/keypoint.py @@ -1,22 +1,22 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import numpy as np -from typing_extensions import Self +from attrs import define, field +from t4_devkit.common.converter import as_array + +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName if TYPE_CHECKING: from t4_devkit.typing import KeypointType -__all__ = ("Keypoint",) +__all__ = ["Keypoint"] -@dataclass +@define @SCHEMAS.register(SchemaName.KEYPOINT) class Keypoint(SchemaBase): """A dataclass to represent schema table of `keypoint.json`. @@ -30,27 +30,8 @@ class Keypoint(SchemaBase): num_keypoints (int): The number of keypoints to be annotated. """ - token: str sample_data_token: str instance_token: str category_tokens: list[str] - keypoints: KeypointType + keypoints: KeypointType = field(converter=as_array) num_keypoints: int - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - sample_data_token: str = data["sample_data_token"] - instance_token: str = data["instance_token"] - category_tokens: list[str] = data["category_tokens"] - keypoints = np.array(data["keypoints"]) - num_keypoints: int = data["num_keypoints"] - - return cls( - token=token, - sample_data_token=sample_data_token, - instance_token=instance_token, - category_tokens=category_tokens, - keypoints=keypoints, - num_keypoints=num_keypoints, - ) diff --git a/t4_devkit/schema/tables/log.py b/t4_devkit/schema/tables/log.py index 5360ca0..8ed5333 100644 --- a/t4_devkit/schema/tables/log.py +++ b/t4_devkit/schema/tables/log.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from typing_extensions import Self +from attrs import define, field +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName __all__ = ("Log",) -@dataclass +@define @SCHEMAS.register(SchemaName.LOG) class Log(SchemaBase): """A dataclass to represent schema table of `log.json`. @@ -23,9 +20,12 @@ class Log(SchemaBase): vehicle (str): Vehicle name. data_captured (str): Date of the data was captured (YYYY-MM-DD-HH-mm-ss). location (str): Area where log was captured. + + Shortcuts: + map_token (str): Foreign key pointing to the map record. + This should be set after instantiated. """ - token: str logfile: str vehicle: str data_captured: str @@ -37,7 +37,3 @@ class Log(SchemaBase): @staticmethod def shortcuts() -> tuple[str]: return ("map_token",) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/map.py b/t4_devkit/schema/tables/map.py index d6428e4..e9da161 100644 --- a/t4_devkit/schema/tables/map.py +++ b/t4_devkit/schema/tables/map.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Self +from attrs import define +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName -__all__ = ("Map",) +__all__ = ["Map"] -@dataclass +@define @SCHEMAS.register(SchemaName.MAP) class Map(SchemaBase): """A dataclass to represent schema table of `map.json`. @@ -24,11 +21,6 @@ class Map(SchemaBase): filename (str): Relative path to the file with the map mask. """ - token: str log_tokens: list[str] category: str filename: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/object_ann.py b/t4_devkit/schema/tables/object_ann.py index 19557fe..a546ab9 100644 --- a/t4_devkit/schema/tables/object_ann.py +++ b/t4_devkit/schema/tables/object_ann.py @@ -1,11 +1,10 @@ from __future__ import annotations import base64 -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING +from attrs import define, field from pycocotools import mask as cocomask -from typing_extensions import Self from ..name import SchemaName from .base import SchemaBase @@ -14,10 +13,10 @@ if TYPE_CHECKING: from t4_devkit.typing import NDArrayU8, RoiType -__all__ = ("ObjectAnn", "RLEMask") +__all__ = ["ObjectAnn", "RLEMask"] -@dataclass +@define class RLEMask: """A dataclass to represent segmentation mask compressed by RLE. @@ -48,7 +47,7 @@ def decode(self) -> NDArrayU8: return cocomask.decode(data).T -@dataclass +@define @SCHEMAS.register(SchemaName.OBJECT_ANN) class ObjectAnn(SchemaBase): """A dataclass to represent schema table of `object_ann.json`. @@ -61,15 +60,17 @@ class ObjectAnn(SchemaBase): attribute_tokens (list[str]): Foreign keys. List of attributes for this annotation. bbox (RoiType): Annotated bounding box. Given as [xmin, ymin, xmax, ymax]. mask (RLEMask): Instance mask using the COCO format compressed by RLE. + + Shortcuts: + category_name (str): Category name. This should be set after instantiated. """ - token: str sample_data_token: str instance_token: str category_token: str attribute_tokens: list[str] - bbox: RoiType - mask: RLEMask + bbox: RoiType = field(converter=tuple) + mask: RLEMask = field(converter=lambda x: RLEMask(**x) if isinstance(x, dict) else x) # shortcuts category_name: str = field(init=False) @@ -78,12 +79,6 @@ class ObjectAnn(SchemaBase): def shortcuts() -> tuple[str]: return ("category_name",) - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - new_data = data.copy() - new_data["mask"] = RLEMask(**data["mask"]) - return cls(**new_data) - @property def width(self) -> int: """Return the width of the bounding box. diff --git a/t4_devkit/schema/tables/registry.py b/t4_devkit/schema/tables/registry.py index 16af39d..2ac1066 100644 --- a/t4_devkit/schema/tables/registry.py +++ b/t4_devkit/schema/tables/registry.py @@ -8,7 +8,7 @@ if TYPE_CHECKING: from .base import SchemaTable -__all__ = ("SCHEMAS",) +__all__ = ["SCHEMAS"] class SchemaRegistry: @@ -74,6 +74,7 @@ def build_from_json(self, key: str | SchemaName, filepath: str) -> list[SchemaTa key = key.value schema: SchemaTable = self.get(key) + print(schema) return schema.from_json(filepath) diff --git a/t4_devkit/schema/tables/sample.py b/t4_devkit/schema/tables/sample.py index dee0e0c..3c8f5f4 100644 --- a/t4_devkit/schema/tables/sample.py +++ b/t4_devkit/schema/tables/sample.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Any - -from typing_extensions import Self +from attrs import define, field from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -__all__ = ("Sample",) +__all__ = ["Sample"] -@dataclass +@define @SCHEMAS.register(SchemaName.SAMPLE) class Sample(SchemaBase): """A dataclass to represent schema table of `sample.json`. @@ -36,22 +33,17 @@ class Sample(SchemaBase): This should be set after instantiated. """ - token: str timestamp: int scene_token: str next: str # noqa: A003 prev: str # shortcuts - data: dict[str, str] = field(default_factory=dict, init=False) - ann_3ds: list[str] = field(default_factory=list, init=False) - ann_2ds: list[str] = field(default_factory=list, init=False) - surface_anns: list[str] = field(default_factory=list, init=False) + data: dict[str, str] = field(factory=dict, init=False) + ann_3ds: list[str] = field(factory=list, init=False) + ann_2ds: list[str] = field(factory=list, init=False) + surface_anns: list[str] = field(factory=list, init=False) @staticmethod def shortcuts() -> tuple[str, str, str, str]: return ("data", "ann_3ds", "ann_2ds", "surface_ann_2ds") - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/sample_annotation.py b/t4_devkit/schema/tables/sample_annotation.py index 623bb6a..da3c468 100644 --- a/t4_devkit/schema/tables/sample_annotation.py +++ b/t4_devkit/schema/tables/sample_annotation.py @@ -1,15 +1,14 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING -import numpy as np -from pyquaternion import Quaternion -from typing_extensions import Self +from attrs import define, field +from t4_devkit.common.converter import as_array, as_quaternion + +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName if TYPE_CHECKING: from t4_devkit.typing import ( @@ -20,10 +19,10 @@ VelocityType, ) -__all__ = ("SampleAnnotation",) +__all__ = ["SampleAnnotation"] -@dataclass +@define @SCHEMAS.register(SchemaName.SAMPLE_ANNOTATION) class SampleAnnotation(SchemaBase): """A dataclass to represent schema table of `sample_annotation.json`. @@ -53,20 +52,19 @@ class SampleAnnotation(SchemaBase): category_name (str): Category name. This should be set after instantiated. """ - token: str sample_token: str instance_token: str attribute_tokens: list[str] visibility_token: str - translation: TranslationType - size: SizeType - rotation: RotationType + translation: TranslationType = field(converter=as_array) + size: SizeType = field(converter=as_array) + rotation: RotationType = field(converter=as_quaternion) num_lidar_pts: int num_radar_pts: int next: str # noqa: A003 prev: str - velocity: VelocityType | None = field(default=None) - acceleration: AccelerationType | None = field(default=None) + velocity: VelocityType | None = field(default=None, converter=as_array) + acceleration: AccelerationType | None = field(default=None, converter=as_array) # shortcuts category_name: str = field(init=False) @@ -74,37 +72,3 @@ class SampleAnnotation(SchemaBase): @staticmethod def shortcuts() -> tuple[str]: return ("category_name",) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - sample_token: str = data["sample_token"] - instance_token: str = data["instance_token"] - attribute_tokens: list[str] = data["attribute_tokens"] - visibility_token: str = data["visibility_token"] - translation = np.array(data["translation"]) - velocity = np.array(data["velocity"]) if data.get("velocity") else None - acceleration = np.array(data["acceleration"]) if data.get("acceleration") else None - size = np.array(data["size"]) - rotation = Quaternion(data["rotation"]) - num_lidar_pts: int = data["num_lidar_pts"] - num_radar_pts: int = data["num_radar_pts"] - next_: str = data["next"] - prev: str = data["prev"] - - return cls( - token=token, - sample_token=sample_token, - instance_token=instance_token, - attribute_tokens=attribute_tokens, - visibility_token=visibility_token, - translation=translation, - velocity=velocity, - acceleration=acceleration, - size=size, - rotation=rotation, - num_lidar_pts=num_lidar_pts, - num_radar_pts=num_radar_pts, - next=next_, - prev=prev, - ) diff --git a/t4_devkit/schema/tables/sample_data.py b/t4_devkit/schema/tables/sample_data.py index d45d2d3..8949137 100644 --- a/t4_devkit/schema/tables/sample_data.py +++ b/t4_devkit/schema/tables/sample_data.py @@ -1,31 +1,21 @@ from __future__ import annotations -from dataclasses import dataclass, field -import sys -from typing import TYPE_CHECKING, Any +from enum import Enum +from typing import TYPE_CHECKING -from typing_extensions import Self +from attrs import define, field +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName - -if sys.version_info < (3, 11): - from enum import Enum - - class StrEnum(str, Enum): - pass - -else: - from enum import StrEnum if TYPE_CHECKING: from .sensor import SensorModality -__all__ = ("SampleData", "FileFormat") +__all__ = ["SampleData", "FileFormat"] -class FileFormat(StrEnum): +class FileFormat(str, Enum): """An enum to represent file formats. Attributes: @@ -72,7 +62,7 @@ def as_ext(self) -> str: return f".{self.value}" -@dataclass +@define @SCHEMAS.register(SchemaName.SAMPLE_DATA) class SampleData(SchemaBase): """A class to represent schema table of `sample_data.json`. @@ -100,19 +90,18 @@ class SampleData(SchemaBase): channel (str): Sensor channel. This should be set after instantiated. """ - token: str sample_token: str ego_pose_token: str calibrated_sensor_token: str filename: str - fileformat: FileFormat + fileformat: FileFormat = field(converter=FileFormat) width: int height: int timestamp: int is_key_frame: bool next: str # noqa: A003 prev: str - is_valid: bool + is_valid: bool = field(default=True) # shortcuts modality: SensorModality = field(init=False) @@ -121,35 +110,3 @@ class SampleData(SchemaBase): @staticmethod def shortcuts() -> tuple[str, str]: return ("modality", "channel") - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - sample_token: str = data["sample_token"] - ego_pose_token: str = data["ego_pose_token"] - calibrated_sensor_token: str = data["calibrated_sensor_token"] - filename: str = data["filename"] - fileformat = FileFormat(data["fileformat"]) - width: int = data["width"] - height: int = data["height"] - timestamp: int = data["timestamp"] - is_key_frame: bool = data["is_key_frame"] - next_: str = data["next"] - prev: str = data["prev"] - is_valid: bool = data.get("is_valid", True) - - return cls( - token=token, - sample_token=sample_token, - ego_pose_token=ego_pose_token, - calibrated_sensor_token=calibrated_sensor_token, - filename=filename, - fileformat=fileformat, - width=width, - height=height, - timestamp=timestamp, - is_key_frame=is_key_frame, - next=next_, - prev=prev, - is_valid=is_valid, - ) diff --git a/t4_devkit/schema/tables/scene.py b/t4_devkit/schema/tables/scene.py index 17ca896..461d8ef 100644 --- a/t4_devkit/schema/tables/scene.py +++ b/t4_devkit/schema/tables/scene.py @@ -1,18 +1,15 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import Any - -from typing_extensions import Self +from attrs import define +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName -__all__ = ("Scene",) +__all__ = ["Scene"] -@dataclass +@define @SCHEMAS.register(SchemaName.SCENE) class Scene(SchemaBase): """A dataclass to represent schema table of `scene.json`. @@ -27,14 +24,9 @@ class Scene(SchemaBase): last_sample_token (str): Foreign key pointing to the last sample in scene. """ - token: str name: str description: str log_token: str nbr_samples: int first_sample_token: str last_sample_token: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - return cls(**data) diff --git a/t4_devkit/schema/tables/sensor.py b/t4_devkit/schema/tables/sensor.py index 814ae30..99f5341 100644 --- a/t4_devkit/schema/tables/sensor.py +++ b/t4_devkit/schema/tables/sensor.py @@ -1,29 +1,17 @@ from __future__ import annotations -from dataclasses import dataclass, field from enum import Enum -import sys -from typing import Any -from typing_extensions import Self +from attrs import define, field +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName - -if sys.version_info < (3, 11): - - class StrEnum(str, Enum): - pass -else: - from enum import StrEnum +__all__ = ["Sensor", "SensorModality"] -__all__ = ("Sensor", "SensorModality") - - -class SensorModality(StrEnum): +class SensorModality(str, Enum): """An enum to represent sensor modalities. Attributes: @@ -37,7 +25,7 @@ class SensorModality(StrEnum): RADAR = "radar" -@dataclass +@define @SCHEMAS.register(SchemaName.SENSOR) class Sensor(SchemaBase): """A dataclass to represent schema table of `sensor.json`. @@ -52,9 +40,8 @@ class Sensor(SchemaBase): first_sd_token (str): The first sample data token corresponding to its sensor channel. """ - token: str channel: str - modality: SensorModality + modality: SensorModality = field(converter=SensorModality) # shortcuts first_sd_token: str = field(init=False) @@ -62,11 +49,3 @@ class Sensor(SchemaBase): @staticmethod def shortcuts() -> tuple[str] | None: return ("first_sd_token",) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - channel = data["channel"] - modality = SensorModality(data["modality"]) - - return cls(token=token, channel=channel, modality=modality) diff --git a/t4_devkit/schema/tables/surface_ann.py b/t4_devkit/schema/tables/surface_ann.py index ce83626..b2f6813 100644 --- a/t4_devkit/schema/tables/surface_ann.py +++ b/t4_devkit/schema/tables/surface_ann.py @@ -1,10 +1,9 @@ from __future__ import annotations -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import numpy as np -from typing_extensions import Self +from attrs import define, field from ..name import SchemaName from .base import SchemaBase @@ -14,10 +13,10 @@ if TYPE_CHECKING: from t4_devkit.typing import RoiType -__all__ = ("SurfaceAnn",) +__all__ = ["SurfaceAnn"] -@dataclass +@define @SCHEMAS.register(SchemaName.SURFACE_ANN) class SurfaceAnn(SchemaBase): """A dataclass to represent schema table of `surface_ann.json`. @@ -29,10 +28,9 @@ class SurfaceAnn(SchemaBase): mask (RLEMask): Segmentation mask using the COCO format compressed by RLE. """ - token: str sample_data_token: str category_token: str - mask: RLEMask + mask: RLEMask = field(converter=lambda x: RLEMask(**x) if isinstance(x, dict) else x) # shortcuts category_name: str = field(init=False) @@ -41,12 +39,6 @@ class SurfaceAnn(SchemaBase): def shortcuts() -> tuple[str]: return ("category_name",) - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - new_data = data.copy() - new_data["mask"] = RLEMask(**data["mask"]) - return cls(**new_data) - @property def bbox(self) -> RoiType: """Return a bounding box corners calculated from polygon vertices. diff --git a/t4_devkit/schema/tables/visibility.py b/t4_devkit/schema/tables/visibility.py index ae87cb1..83a6c5f 100644 --- a/t4_devkit/schema/tables/visibility.py +++ b/t4_devkit/schema/tables/visibility.py @@ -1,29 +1,19 @@ from __future__ import annotations -from dataclasses import dataclass -import sys -from typing import Any import warnings +from enum import Enum +from attrs import define, field from typing_extensions import Self +from ..name import SchemaName from .base import SchemaBase from .registry import SCHEMAS -from ..name import SchemaName - -if sys.version_info < (3, 11): - from enum import Enum - - class StrEnum(str, Enum): - pass - -else: - from enum import StrEnum __all__ = ("Visibility", "VisibilityLevel") -class VisibilityLevel(StrEnum): +class VisibilityLevel(str, Enum): """An enum to represent visibility levels. Attributes: @@ -69,7 +59,7 @@ def _from_alias(level: str) -> Self: return VisibilityLevel.UNAVAILABLE -@dataclass +@define @SCHEMAS.register(SchemaName.VISIBILITY) class Visibility(SchemaBase): """A dataclass to represent schema table of `visibility.json`. @@ -80,14 +70,9 @@ class Visibility(SchemaBase): description (str): Description of visibility level. """ - token: str - level: VisibilityLevel + level: VisibilityLevel = field( + converter=lambda x: VisibilityLevel.from_value(x) + if not isinstance(x, VisibilityLevel) + else VisibilityLevel(x) + ) description: str - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> Self: - token: str = data["token"] - level = VisibilityLevel.from_value(data["level"]) - description: str = data["description"] - - return cls(token=token, level=level, description=description)