Skip to content

Commit

Permalink
Support deserializing missing fields that have defaults (#9)
Browse files Browse the repository at this point in the history
Adds support in `deserialize` for missing fields that are defined with default values.

If a field of a `NamedTuple`- or `@dataclass`-derived type is defined with a default value
(e.g. `seconds: int = 10`), then it is OK for a raw value passed to `deserialize` to omit
this field. When that happens, the field's default value is used.

The prior behavior was to raise a `MissingRequired` exception.
This was confusing because it was technically the wrong error to raise: a field with a default value is not required.

New test cases validate this default-respecting behavior during deserialization.

Version `0.4.0`
  • Loading branch information
malcolmgreaves committed Mar 9, 2022
1 parent 5d0f77d commit 8becdab
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .python-version
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
3.8.3
3.8.5
3.7.7
34 changes: 26 additions & 8 deletions core_utils/schema.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from typing import Type, Iterable, Union, Any, Mapping, Sequence, get_args, cast
from typing import (
Type,
Iterable,
Union,
Any,
Mapping,
Sequence,
get_args,
cast,
Callable,
Tuple,
)
from dataclasses import is_dataclass

from core_utils.common import type_name, checkable_type
from core_utils.serialization import (
is_typed_namedtuple,
_namedtuple_field_types,
_dataclass_field_types,
_dataclass_field_types_defaults,
)


Expand Down Expand Up @@ -37,12 +48,19 @@ def dict_type_representation(nt_or_dc_type: Type) -> Discover:

def _dict_type(t: type):
if is_typed_namedtuple(t) or is_dataclass(t):
field_types_of = (
_namedtuple_field_types
if is_typed_namedtuple(t)
else _dataclass_field_types
)
accum = {name: _dict_type(field_type) for name, field_type in field_types_of(t)}
if is_typed_namedtuple(t):
field_types_of: Callable[
[], Iterable[Tuple[str, type]]
] = lambda: _namedtuple_field_types(
t # type: ignore
)
else:

def field_types_of() -> Iterable[Tuple[str, type]]:
for a, b, _ in _dataclass_field_types_defaults(t):
yield a, b

accum = {name: _dict_type(field_type) for name, field_type in field_types_of()}
return accum

else:
Expand Down
74 changes: 59 additions & 15 deletions core_utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import traceback
from enum import Enum
from typing import (
Any,
Expand All @@ -15,8 +16,10 @@
get_args,
Union,
cast,
List,
Dict,
)
from dataclasses import dataclass, is_dataclass, Field
from dataclasses import dataclass, is_dataclass, Field, _MISSING_TYPE

from core_utils.common import type_name, checkable_type, split_module_value

Expand Down Expand Up @@ -281,6 +284,7 @@ def _namedtuple_from_dict(
_namedtuple_field_types(namedtuple_type),
data,
namedtuple_type,
_namedtuple_field_defaults(namedtuple_type),
custom,
)
)
Expand Down Expand Up @@ -309,6 +313,15 @@ def _namedtuple_field_types(
return namedtuple_type._field_types.items() # type: ignore


def _namedtuple_field_defaults(
    namedtuple_type: Type[SomeNamedTuple],
) -> Mapping[str, Any]:
    """Map each defaulted field of a NamedTuple type to its serialized default.

    Only the fields listed in the type's ``_field_defaults`` appear in the
    result. Every default is passed through ``serialize`` with
    ``no_none_values=False`` before being stored.
    """
    serialized_defaults = {}
    for field_name, default_value in namedtuple_type._field_defaults.items():  # type: ignore
        serialized_defaults[field_name] = serialize(
            default_value, no_none_values=False
        )
    return serialized_defaults


def _dataclass_from_dict(
dataclass_type: Type, data: dict, custom: Optional[CustomFormat],
) -> Any:
Expand All @@ -319,7 +332,9 @@ def _dataclass_from_dict(
)
if is_dataclass(dataclass_type) or is_generic_dataclass:
try:
field_and_types = list(_dataclass_field_types(dataclass_type))
all_field_type_default: List[Tuple[str, Type, Optional[Any]]] = list(
_dataclass_field_types_defaults(dataclass_type)
)
except AttributeError as ae:
raise TypeError(
"Did you pass-in a type that is decorated with @dataclass? "
Expand All @@ -329,8 +344,15 @@ def _dataclass_from_dict(
ae,
)

field_and_types: List[Tuple[str, Type]] = []
field_defaults: Dict[str, Any] = dict()
for field, _type, maybe_default in all_field_type_default:
field_and_types.append((field, _type))
if maybe_default is not None:
field_defaults[field] = maybe_default

deserialized_fields = _values_for_type(
field_and_types, data, dataclass_type, custom
field_and_types, data, dataclass_type, field_defaults, custom
)
deserialized_fields = list(deserialized_fields)
field_values = dict(
Expand All @@ -345,30 +367,40 @@ def _dataclass_from_dict(
)


def _dataclass_field_types(dataclass_type: Type) -> Iterable[Tuple[str, Type]]:
def _dataclass_field_types_defaults(
dataclass_type: Type,
) -> Iterable[Tuple[str, Type, Optional[Any]]]:
"""Obtain the fields & their expected types for the given @dataclass type.
"""
if hasattr(dataclass_type, "__origin__"):
dataclass_fields = dataclass_type.__origin__.__dataclass_fields__ # type: ignore
generic_to_concrete = dict(_align_generic_concrete(dataclass_type))

def as_name_and_type(data_field: Field) -> Tuple[str, Type]:
def handle_field(data_field: Field) -> Tuple[str, Type, Optional[Any]]:
if data_field.type in generic_to_concrete:
typ = generic_to_concrete[data_field.type]
elif hasattr(data_field.type, "__parameters__"):
tn = _fill(generic_to_concrete, data_field.type)
typ = _exec(data_field.type.__origin__, tn)
else:
typ = data_field.type
return data_field.name, typ
return data_field.name, typ, _default_of(data_field)

else:
dataclass_fields = dataclass_type.__dataclass_fields__ # type: ignore

def as_name_and_type(data_field: Field) -> Tuple[str, Type]:
return data_field.name, data_field.type
def handle_field(data_field: Field) -> Tuple[str, Type, Optional[Any]]:
return data_field.name, data_field.type, _default_of(data_field)

return list(map(handle_field, dataclass_fields.values()))

return list(map(as_name_and_type, dataclass_fields.values()))

def _default_of(data_field: Field) -> Optional[Any]:
    """Return the serialized default value of a dataclass field, if it has one.

    Both ways of declaring a default are honored: a plain default
    (``x: int = 10``) and a factory (``x: list = field(default_factory=list)``).
    The default is passed through ``serialize`` with ``no_none_values=False``.

    Returns ``None`` when the field declares neither kind of default. Callers
    therefore cannot tell "no default" apart from a default whose serialized
    form is itself ``None``.
    """
    if not isinstance(data_field.default, _MISSING_TYPE):
        return serialize(data_field.default, no_none_values=False)

    factory = data_field.default_factory  # type: ignore
    if not isinstance(factory, _MISSING_TYPE):
        # A dataclass field may use `default_factory` instead of `default`.
        # The factory is invoked here, eagerly, to produce a fresh value.
        return serialize(factory(), no_none_values=False)

    return None


def _align_generic_concrete(
Expand Down Expand Up @@ -453,8 +485,9 @@ def _exec(origin_type, tn):

def _values_for_type(
field_name_expected_type: Iterable[Tuple[str, Type]],
data: dict,
data: Mapping[str, Any],
type_data: Type,
field_to_default: Mapping[str, Any],
custom: Optional[CustomFormat],
) -> Iterable:
"""Constructs an instance of :param:`type_data` using the data in :param:`data`, with
Expand Down Expand Up @@ -497,6 +530,10 @@ def _values_for_type(
actual_value=value,
)

elif field_name in field_to_default:
value = field_to_default[field_name]
# use a default value for an Optional field
# before defaulting to the None value
elif _is_optional(field_type): # type: ignore
value = None
else:
Expand All @@ -511,12 +548,19 @@ def _values_for_type(
yield deserialize(field_type, value, custom) # type: ignore
else:
yield None
except (FieldDeserializeFail, MissingRequired):
raise
except Exception as e:
raise FieldDeserializeFail(
field_name=field_name, expected_type=field_type, actual_value=value
) from e
print(
"ERROR deserializing field:'"
+ str(field_name)
+ "'\n"
+ traceback.format_exc()
)
if isinstance(e, (FieldDeserializeFail, MissingRequired)):
raise e
else:
raise FieldDeserializeFail(
field_name=field_name, expected_type=field_type, actual_value=value
) from e


@dataclass(frozen=True)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pywise"
version = "0.3.2"
version = "0.4.0"
description = "Robust serialization support for NamedTuple & @dataclass data types."
authors = ["Malcolm Greaves <[email protected]>"]
homepage = "https://github.com/malcolmgreaves/pywise"
Expand Down
87 changes: 87 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,3 +493,90 @@ def test_serialize_none_special_cases_mapping():
assert len(s) == 1
assert deserialize(Mapping[str, Optional[int]], s) == m_empty
assert deserialize(Mapping[str, Optional[int]], serialize(m_empty)) == {}


# Test fixture: a frozen dataclass whose only field has a default value.
@dataclass(frozen=True)
class HasDefaultsDC:
    # Defaulted so that a payload without "value" can still be deserialized.
    value: int = 10


# Test fixture: a NamedTuple whose only field has a default value.
class HasDefaultsNT(NamedTuple):
    # Defaulted so that a payload without "name" can still be deserialized.
    name: str = "<noname>"


def test_serialize_has_defaults_dc():
    """An empty payload deserializes to the dataclass's default field value."""
    from_empty: HasDefaultsDC = deserialize(HasDefaultsDC, {})
    assert from_empty.value == 10

    serialized = serialize(from_empty)
    assert serialized == {"value": 10}

    # Round-tripping a default-constructed instance gives the same payload.
    round_tripped = deserialize(HasDefaultsDC, serialize(HasDefaultsDC()))
    assert serialize(round_tripped) == serialized


def test_serialize_has_defaults_nt():
    """An empty payload deserializes to the NamedTuple's default field value."""
    from_empty: HasDefaultsNT = deserialize(HasDefaultsNT, {})
    assert from_empty.name == "<noname>"

    serialized = serialize(from_empty)
    assert serialized == {"name": "<noname>"}

    # Round-tripping a default-constructed instance gives the same payload.
    round_tripped = deserialize(HasDefaultsNT, serialize(HasDefaultsNT()))
    assert serialize(round_tripped) == serialized


# Test fixture: innermost level of the nested-defaults chain (no `next` field).
class Next3(NamedTuple):
    # Both fields default to default-constructed fixture instances.
    dc_value: HasDefaultsDC = HasDefaultsDC()
    nt_name: HasDefaultsNT = HasDefaultsNT()


# Test fixture: middle level of the nested-defaults chain; may link to a Next3.
class Next2(NamedTuple):
    # Both value fields default to default-constructed fixture instances.
    dc_value: HasDefaultsDC = HasDefaultsDC()
    nt_name: HasDefaultsNT = HasDefaultsNT()
    # Optional link to the next nesting level; absent by default.
    next: Optional[Next3] = None


# Test fixture: first NamedTuple level of the nested-defaults chain; may link to a Next2.
class Next1(NamedTuple):
    # Both value fields default to default-constructed fixture instances.
    dc_value: HasDefaultsDC = HasDefaultsDC()
    nt_name: HasDefaultsNT = HasDefaultsNT()
    # Optional link to the next nesting level; absent by default.
    next: Optional[Next2] = None


# Test fixture: frozen dataclass at the root of the chain, mixing dataclass and
# NamedTuple fields that all have defaults.
@dataclass(frozen=True)
class NestedDefaultsMixed:
    # Both value fields default to default-constructed fixture instances.
    dc_value: HasDefaultsDC = HasDefaultsDC()
    nt_name: HasDefaultsNT = HasDefaultsNT()
    # Optional link into the NamedTuple chain; absent by default.
    next: Optional[Next1] = None


def test_serialized_nested_defaults_basic():
    """Deserializing an empty payload fills each field from its declared default."""
    from_empty: NestedDefaultsMixed = deserialize(NestedDefaultsMixed, {})
    assert from_empty.dc_value.value == 10
    assert from_empty.nt_name.name == "<noname>"
    assert from_empty.next is None

    expected = {
        "dc_value": {"value": 10},
        "nt_name": {"name": "<noname>"},
        "next": None,
    }
    serialized = serialize(from_empty, no_none_values=False)
    assert serialized == expected

    # Round-tripping a default-constructed instance gives the same payload.
    round_tripped = deserialize(
        NestedDefaultsMixed, serialize(NestedDefaultsMixed())
    )
    assert serialize(round_tripped, no_none_values=False) == serialized


def test_serialized_nested_defaults_advanced():
    """A deeply nested value with some defaults overridden round-trips unchanged."""
    # Build the chain from the inside out; fields not named keep their defaults.
    innermost = Next3(nt_name=HasDefaultsNT("goodbye universe?"))
    middle = Next2(dc_value=HasDefaultsDC(-50), next=innermost)
    outer = Next1(nt_name=HasDefaultsNT("hello world!"), next=middle)
    nested = NestedDefaultsMixed(
        dc_value=HasDefaultsDC(9999),
        nt_name=HasDefaultsNT("powerlevel"),
        next=outer,
    )

    payload = serialize(nested)
    assert nested == deserialize(NestedDefaultsMixed, payload)

0 comments on commit 8becdab

Please sign in to comment.