diff --git a/t4_devkit/schema/serialize.py b/t4_devkit/schema/serialize.py index e8853af..0e3c79d 100644 --- a/t4_devkit/schema/serialize.py +++ b/t4_devkit/schema/serialize.py @@ -1,11 +1,10 @@ from __future__ import annotations from enum import Enum -from functools import partial -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any import numpy as np -from attrs import asdict +from attrs import asdict, filters from pyquaternion import Quaternion if TYPE_CHECKING: @@ -36,34 +35,15 @@ def serialize_schema(data: SchemaTable) -> dict: Returns: Serialized dict data. """ - dict_factory = partial(_schema_as_dict_factory, excludes=data.shortcuts()) - return asdict(data, dict_factory=dict_factory) - - -def _schema_as_dict_factory( - data: list[tuple[str, Any]], *, excludes: Sequence[str] | None = None -) -> dict: - """A factory to convert schema dataclass field to dict data. - - Args: - data (list[tuple[str, Any]]): Some data of dataclass field. - excludes (Sequence[str] | None, optional): Sequence of field names to be excluded. - - Returns: - Converted dict data. - """ - - def _convert_value(value: Any) -> Any: - if isinstance(value, np.ndarray): - return value.tolist() - elif isinstance(value, Quaternion): - return value.q.tolist() - elif isinstance(value, Enum): - return value.value - return value - - return ( - {k: _convert_value(v) for k, v in data} - if excludes is None - else {k: _convert_value(v) for k, v in data if k not in excludes} - ) + excludes = filters.exclude(*data.shortcuts()) if data.shortcuts() is not None else None + return asdict(data, filter=excludes, value_serializer=_value_serializer) + + +def _value_serializer(data: SchemaTable, attr: Any, value: Any) -> Any: + if isinstance(value, np.ndarray): + return value.tolist() + elif isinstance(value, Quaternion): + return value.q.tolist() + elif isinstance(value, Enum): + return value.value + return value