Skip to content

Commit

Permalink
Propagate CustomFormat in serialize & deserialize (#2)
Browse files Browse the repository at this point in the history
Bugfix - propagate the `CustomFormat` supplied in a `serialize` & `deserialize` calls. Updated the internals in `serialization.py` to ensure that the supplied `CustomFormat` type -> (de)serializer mapping is used when the routine(s) encounter a `Mapping`, `Iterable`, `NamedTuple` or `@dataclass`-deriving input. New tests in `test_custom_serialization` cover these cases.

Version `0.1.1`
  • Loading branch information
malcolmgreaves authored Aug 11, 2020
1 parent 4493c08 commit 9d7b5df
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 33 deletions.
47 changes: 30 additions & 17 deletions core_utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ def serialize(value: Any, custom: Optional[CustomFormat] = None) -> Any:
return custom[type(value)](value)

elif is_namedtuple(value):
return {k: serialize(raw_val) for k, raw_val in value._asdict().items()}
return {k: serialize(raw_val, custom) for k, raw_val in value._asdict().items()}

elif is_dataclass(value):
return {k: serialize(v) for k, v in value.__dict__.items()}
return {k: serialize(v, custom) for k, v in value.__dict__.items()}

elif isinstance(value, Mapping):
return {serialize(k): serialize(v) for k, v in value.items()}
return {serialize(k, custom): serialize(v, custom) for k, v in value.items()}

elif isinstance(value, Iterable) and not isinstance(value, str):
return list(map(serialize, value))
return list(map(lambda x: serialize(x, custom), value))

elif isinstance(value, Enum):
# serialize the enum value's name as it's a better identifier than the
Expand Down Expand Up @@ -109,18 +109,18 @@ def deserialize(
checking_type_value: Type = checkable_type(type_value)

if is_namedtuple(checking_type_value):
return _namedtuple_from_dict(type_value, value)
return _namedtuple_from_dict(type_value, value, custom)

elif is_dataclass(checking_type_value):
return _dataclass_from_dict(type_value, value)
return _dataclass_from_dict(type_value, value, custom)

# NOTE: Need to have type_value instead of checking_type_value here !
elif _is_optional(type_value):
# obtain generic parameter \& deserialize
if value is None:
return None
else:
return deserialize(type_value.__args__[0], value)
return deserialize(type_value.__args__[0], value, custom)

# NOTE: Need to have type_value instead of checking_type_value here !
elif _is_union(type_value):
Expand All @@ -129,7 +129,7 @@ def deserialize(
# possible types
print(possible_type)
try:
return deserialize(possible_type, value)
return deserialize(possible_type, value, custom)
except Exception:
pass
raise FieldDeserializeFail(
Expand All @@ -139,20 +139,23 @@ def deserialize(
elif issubclass(checking_type_value, Mapping):
k_type, v_type = type_value.__args__ # type: ignore
return {
deserialize(k_type, k): deserialize(v_type, v) for k, v in value.items()
deserialize(k_type, k, custom): deserialize(v_type, v, custom)
for k, v in value.items()
}

elif issubclass(checking_type_value, Tuple) and checking_type_value != str: # type: ignore
tuple_type_args = type_value.__args__
converted = map(
lambda type_val_pair: deserialize(type_val_pair[0], type_val_pair[1]),
lambda type_val_pair: deserialize(
type_val_pair[0], type_val_pair[1], custom
),
zip(tuple_type_args, value),
)
return tuple(converted)

elif issubclass(checking_type_value, Iterable) and checking_type_value != str:
(i_type,) = type_value.__args__ # type: ignore
converted = map(lambda x: deserialize(i_type, x), value)
converted = map(lambda x: deserialize(i_type, x, custom), value)
if issubclass(checking_type_value, Set):
return set(converted)
else:
Expand Down Expand Up @@ -220,7 +223,9 @@ def is_typed_namedtuple(x: Any) -> bool:


def _namedtuple_from_dict(
namedtuple_type: Type[SomeNamedTuple], data: dict
namedtuple_type: Type[SomeNamedTuple],
data: dict,
custom: Optional[CustomFormat] = None,
) -> SomeNamedTuple:
"""Initializes an instance of the given namedtuple type from a dict of its names and values.
Expand All @@ -231,7 +236,10 @@ def _namedtuple_from_dict(
try:
field_values = tuple(
_values_for_type(
_namedtuple_field_types(namedtuple_type), data, namedtuple_type
_namedtuple_field_types(namedtuple_type),
data,
namedtuple_type,
custom,
)
)
return namedtuple_type._make(field_values) # type: ignore
Expand Down Expand Up @@ -259,14 +267,16 @@ def _namedtuple_field_types(
return namedtuple_type._field_types.items() # type: ignore


def _dataclass_from_dict(dataclass_type: Type, data: dict) -> Any:
def _dataclass_from_dict(
dataclass_type: Type, data: dict, custom: Optional[CustomFormat] = None
) -> Any:
"""Constructs an @dataclass instance using :param:`data`.
"""
if is_dataclass(dataclass_type):
try:
field_and_types = list(_dataclass_field_types(dataclass_type))
deserialized_fields = _values_for_type(
field_and_types, data, dataclass_type
field_and_types, data, dataclass_type, custom
)
field_values = dict(
zip(map(lambda x: x[0], field_and_types), deserialized_fields)
Expand Down Expand Up @@ -299,7 +309,10 @@ def as_name_and_type(data_field: Field) -> Tuple[str, Type]:


def _values_for_type(
field_name_expected_type: Iterable[Tuple[str, Type]], data: dict, type_data: Type,
field_name_expected_type: Iterable[Tuple[str, Type]],
data: dict,
type_data: Type,
custom: Optional[CustomFormat] = None,
) -> Iterable:
"""Constructs an instance of :param:`type_data` using the data in :param:`data`, with
field names & expected types of :param:`field_name_expected_type` guiding construction.
Expand Down Expand Up @@ -352,7 +365,7 @@ def _values_for_type(

try:
if value is not None:
yield deserialize(type_value=field_type, value=value) # type: ignore
yield deserialize(field_type, value, custom) # type: ignore
else:
yield None
except (FieldDeserializeFail, MissingRequired):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pywise"
version = "0.1.0"
version = "0.1.1"
description = "Robust serialization support for NamedTuple & @dataclass data types."
authors = ["Malcolm Greaves <[email protected]>"]
homepage = "https://github.com/malcolmgreaves/pywise"
Expand Down
83 changes: 68 additions & 15 deletions tests/test_custom_serialization.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from typing import Tuple
from dataclasses import dataclass
from typing import Tuple, NamedTuple, Mapping, Sequence

import numpy as np
import torch
Expand Down Expand Up @@ -42,20 +43,7 @@ def multi_dim_shape() -> Tuple[int, int, int, int, int]:


def _test_procedue(cs, cd, simple_len, n_rando_times, multi_shape, make):
def check(*, actual, expected):
assert isinstance(
actual, type(expected)
), f"Expecting {type(expected)} recieved {type(actual)}"
assert (expected == actual).all() # type: ignore

def roundtrip(a):
s1 = serialize(a, custom=cs)
d1 = deserialize(type(a), s1, custom=cd)
check(actual=d1, expected=a)
j = json.dumps(s1)
s2 = json.loads(j)
d2 = deserialize(type(a), s2, custom=cd)
check(actual=d2, expected=a)
roundtrip = lambda a: _roundtrip(a, cs, cd, _check_array_like)

roundtrip(np.zeros(simple_len))

Expand All @@ -64,6 +52,23 @@ def roundtrip(a):
roundtrip(arr)


def _roundtrip(a, cs, cd, check):
s1 = serialize(a, custom=cs)
d1 = deserialize(type(a), s1, custom=cd)
check(actual=d1, expected=a)
j = json.dumps(s1)
s2 = json.loads(j)
d2 = deserialize(type(a), s2, custom=cd)
check(actual=d2, expected=a)


def _check_array_like(*, actual, expected):
assert isinstance(
actual, type(expected)
), f"Expecting {type(expected)} recieved {type(actual)}"
assert (expected == actual).all() # type: ignore


def test_serialization_numpy_array(
custom_serialize,
custom_deserialize,
Expand Down Expand Up @@ -96,3 +101,51 @@ def test_serialization_torch_tensor(
multi_dim_shape,
lambda s: torch.from_numpy(np.random.random(s)),
)


def test_custom_serialize_map(custom_serialize, custom_deserialize, multi_dim_shape):
class MNT(NamedTuple):
field: Mapping[str, np.ndarray]

@dataclass(frozen=True)
class MDC:
field: Mapping[str, np.ndarray]

mnt = MNT(field={"an_id": np.random.random(multi_dim_shape)})
mdc = MDC(field={"an_id": np.random.random(multi_dim_shape)})

def check(*, actual, expected):
assert isinstance(actual, type(expected))
assert "an_id" in actual.field
_check_array_like(
actual=actual.field["an_id"], expected=expected.field["an_id"]
)

_roundtrip(mnt, custom_serialize, custom_deserialize, check)
_roundtrip(mdc, custom_serialize, custom_deserialize, check)


def test_custom_serialize_iterable(
custom_serialize, custom_deserialize, multi_dim_shape
):
class MNT(NamedTuple):
field: Sequence[np.ndarray]

@dataclass(frozen=True)
class MDC:
field: Sequence[np.ndarray]

mnt = MNT(
field=[np.random.random(multi_dim_shape), np.random.random(multi_dim_shape)]
)
mdc = MDC(
field=[np.random.random(multi_dim_shape), np.random.random(multi_dim_shape)]
)

def check(*, actual, expected):
assert isinstance(actual, type(expected))
for a, e in zip(actual.field, expected.field):
_check_array_like(actual=a, expected=e)

_roundtrip(mnt, custom_serialize, custom_deserialize, check)
_roundtrip(mdc, custom_serialize, custom_deserialize, check)

0 comments on commit 9d7b5df

Please sign in to comment.