More robust enum serialization
brentyi committed Nov 24, 2021
1 parent ec216b8 commit 7ff0e00
Showing 4 changed files with 86 additions and 44 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -77,7 +77,7 @@ And, from `python simple.py --field1 string --field2 4`:
```
Args(field1='string', field2=4)
!Args
!dataclass:Args
field1: string
field2: 4
```
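The tag change above reflects this commit's new scheme: dataclasses serialize under a `!dataclass:<ClassName>` tag and enum members under an `!enum:<EnumName>` tag. A minimal sketch of the new output, using a hypothetical `Color` enum and `color` field that are not part of the repository's `Args` example:

```python
import dataclasses
import enum

import dcargs


class Color(enum.Enum):
    RED = enum.auto()
    BLUE = enum.auto()


@dataclasses.dataclass
class Args:
    field1: str
    field2: int
    color: Color


print(dcargs.to_yaml(Args(field1="string", field2=4, color=Color.RED)))
# Roughly (timestamp and key order may differ):
#   # YAML generated via dcargs, at 2021-11-24-12:00:00.
#   !dataclass:Args
#   color: !enum:Color RED
#   field1: string
#   field2: 4
```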
82 changes: 58 additions & 24 deletions dcargs/_serialization.py
@@ -1,32 +1,38 @@
import dataclasses
import datetime
import enum
from typing import IO, Any, Optional, Set, Type, TypeVar, Union

import yaml
from typing_extensions import get_origin

from . import _resolver

DATACLASS_YAML_TAG_PREFIX = "!"
ENUM_YAML_TAG_PREFIX = "!enum:"
DATACLASS_YAML_TAG_PREFIX = "!dataclass:"

DataclassType = TypeVar("DataclassType")


def _get_contained_dataclass_types_from_instance(instance: Any) -> Set[Type]:
"""Takes a dataclass instance, and recursively searches its cihldren for dataclass
def _get_contained_special_types_from_instance(instance: Any) -> Set[Type]:
"""Takes an object and recursively searches its cihldren for dataclass or enum
types."""
if not dataclasses.is_dataclass(instance):
if issubclass(type(instance), enum.Enum):
return {type(instance)}
elif not dataclasses.is_dataclass(instance):
return set()

out = {type(instance)}
for v in vars(instance).values():
out |= _get_contained_dataclass_types_from_instance(v)
out |= _get_contained_special_types_from_instance(v)
return out


def _get_contained_dataclass_types_from_type(
def _get_contained_special_types_from_type(
cls: Type,
_parent_contained_dataclasses: Optional[Set[Type]] = None,
) -> Set[Type]:
"""Takes a dataclass type, and recursively searches its fields for dataclass
"""Takes a dataclass type, and recursively searches its fields for dataclass or enum
types."""
assert _resolver.is_dataclass(cls)
parent_contained_dataclasses = (
@@ -41,11 +47,13 @@ def _get_contained_dataclass_types_from_type(

def handle_type(typ) -> Set[Type]:
if _resolver.is_dataclass(typ) and typ not in parent_contained_dataclasses:
return _get_contained_dataclass_types_from_type(
return _get_contained_special_types_from_type(
typ,
_parent_contained_dataclasses=contained_dataclasses
| parent_contained_dataclasses,
)
elif type(typ) is enum.EnumMeta:
return {typ}
return set()

# Handle generics.
@@ -74,7 +82,7 @@ class DataclassLoader(yaml.Loader):
# => let's just keep things simple, assert uniqueness for now. Easier to add new
# features later than remove them.

contained_types = list(_get_contained_dataclass_types_from_type(cls))
contained_types = list(_get_contained_special_types_from_type(cls))
contained_type_names = list(map(lambda cls: cls.__name__, contained_types))
assert len(set(contained_type_names)) == len(
contained_type_names
@@ -83,14 +91,25 @@ class DataclassLoader(yaml.Loader):
loader: yaml.Loader
node: yaml.Node

def make_constructor(typ: Type):
def make_dataclass_constructor(typ: Type):
return lambda loader, node: typ(**loader.construct_mapping(node))

def make_enum_constructor(typ: Type):
return lambda loader, node: typ[loader.construct_python_str(node)]

for typ, name in zip(contained_types, contained_type_names):
DataclassLoader.add_constructor(
tag=DATACLASS_YAML_TAG_PREFIX + name,
constructor=make_constructor(typ),
)
if dataclasses.is_dataclass(typ):
DataclassLoader.add_constructor(
tag=DATACLASS_YAML_TAG_PREFIX + name,
constructor=make_dataclass_constructor(typ),
)
elif issubclass(typ, enum.Enum):
DataclassLoader.add_constructor(
tag=ENUM_YAML_TAG_PREFIX + name,
constructor=make_enum_constructor(typ),
)
else:
assert False

return DataclassLoader

@@ -99,7 +118,7 @@ def _make_dumper(instance: Any) -> Type[yaml.Dumper]:
class DataclassDumper(yaml.Dumper):
pass

contained_types = list(_get_contained_dataclass_types_from_instance(instance))
contained_types = list(_get_contained_special_types_from_instance(instance))
contained_type_names = list(map(lambda cls: cls.__name__, contained_types))
assert len(set(contained_type_names)) == len(
contained_type_names
@@ -110,14 +129,22 @@ class DataclassDumper(yaml.Dumper):
field: dataclasses.Field

def make_representer(name: str):
return lambda dumper, data: dumper.represent_mapping(
tag=DATACLASS_YAML_TAG_PREFIX + name,
mapping={
field.name: getattr(data, field.name)
for field in dataclasses.fields(data)
if field.init
},
)
def representer(dumper, data):
if dataclasses.is_dataclass(data):
return dumper.represent_mapping(
tag=DATACLASS_YAML_TAG_PREFIX + name,
mapping={
field.name: getattr(data, field.name)
for field in dataclasses.fields(data)
if field.init
},
)
elif isinstance(data, enum.Enum):
return dumper.represent_scalar(
tag=ENUM_YAML_TAG_PREFIX + name, value=data.name
)

return representer

for typ, name in zip(contained_types, contained_type_names):
DataclassDumper.add_representer(typ, make_representer(name))
@@ -136,7 +163,14 @@ def from_yaml(
return out


def _timestamp() -> str:
"""Get a current timestamp as a string. Example format: `2021-11-05-15:46:32`."""
return datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")


def to_yaml(instance: Any) -> str:
"""Serialize a dataclass; returns a yaml-compatible string that can be deserialized
via `dcargs.from_yaml()`."""
return yaml.dump(instance, Dumper=_make_dumper(instance))
return f"# YAML generated via dcargs, at {_timestamp()}.\n" + yaml.dump(
instance, Dumper=_make_dumper(instance)
)
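
Reading the dumper and loader changes above together: dumping tags each enum member with `!enum:<Name>` and its member name, and loading reconstructs it via `typ[name]`, so a dump/load round trip should return an equal instance. A usage sketch, reusing the hypothetical `Color`/`Args` classes from the earlier snippet and assuming `dcargs.from_yaml` takes the target class followed by the YAML string:

```python
import dcargs

original = Args(field1="string", field2=4, color=Color.BLUE)
serialized = dcargs.to_yaml(original)

# Assumed call signature: from_yaml(cls, yaml_string) -> instance of cls.
restored = dcargs.from_yaml(Args, serialized)

assert restored == original
# Enum members are looked up by name, so identity with the singleton member holds.
assert restored.color is Color.BLUE
```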
1 change: 1 addition & 0 deletions examples/example.py
@@ -41,3 +41,4 @@ class ExperimentConfig:
if __name__ == "__main__":
config = dcargs.parse(ExperimentConfig, description=__doc__)
print(config)
print(dcargs.to_yaml(config))
45 changes: 26 additions & 19 deletions tests/test_generics_and_serialization.py
@@ -1,4 +1,5 @@
import dataclasses
import enum
from typing import Generic, Type, TypeVar, Union

import pytest
@@ -15,12 +16,17 @@ def _check_serialization_identity(cls: Type[T], instance: T) -> None:
ScalarType = TypeVar("ScalarType")


class CoordinateFrame(enum.Enum):
WORLD = enum.auto()
CAMERA = enum.auto()


@dataclasses.dataclass
class Point3(Generic[ScalarType]):
x: ScalarType
y: ScalarType
z: ScalarType
frame_id: str
frame: CoordinateFrame


def test_simple_generic():
@@ -38,20 +44,21 @@ class SimpleGeneric:
"2.2",
"--point-continuous.z",
"3.2",
"--point-continuous.frame-id",
"world",
"--point-continuous.frame",
"WORLD",
"--point-discrete.x",
"1",
"--point-discrete.y",
"2",
"--point-discrete.z",
"3",
"--point-discrete.frame-id",
"world",
"--point-discrete.frame",
"WORLD",
],
)
assert parsed_instance == SimpleGeneric(
Point3(1.2, 2.2, 3.2, "world"), Point3(1, 2, 3, "world")
Point3(1.2, 2.2, 3.2, CoordinateFrame.WORLD),
Point3(1, 2, 3, CoordinateFrame.WORLD),
)
_check_serialization_identity(SimpleGeneric, parsed_instance)

@@ -66,16 +73,16 @@ class SimpleGeneric:
"2.2",
"--point-continuous.z",
"3.2",
"--point-continuous.frame-id",
"world",
"--point-continuous.frame",
"WORLD",
"--point-discrete.x",
"1.5",
"--point-discrete.y",
"2.5",
"--point-discrete.z",
"3.5",
"--point-discrete.frame-id",
"world",
"--point-discrete.frame",
"WORLD",
],
)

@@ -96,30 +103,30 @@ class Triangle(Generic[ScalarType]):
"1.2",
"--a.z",
"1.3",
"--a.frame-id",
"world",
"--a.frame",
"WORLD",
"--b.x",
"1.0",
"--b.y",
"1.2",
"--b.z",
"1.3",
"--b.frame-id",
"world",
"--b.frame",
"WORLD",
"--c.x",
"1.0",
"--c.y",
"1.2",
"--c.z",
"1.3",
"--c.frame-id",
"world",
"--c.frame",
"WORLD",
],
)
assert parsed_instance == Triangle(
Point3(1.0, 1.2, 1.3, "world"),
Point3(1.0, 1.2, 1.3, "world"),
Point3(1.0, 1.2, 1.3, "world"),
Point3(1.0, 1.2, 1.3, CoordinateFrame.WORLD),
Point3(1.0, 1.2, 1.3, CoordinateFrame.WORLD),
Point3(1.0, 1.2, 1.3, CoordinateFrame.WORLD),
)
_check_serialization_identity(Triangle[float], parsed_instance)
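
The `_check_serialization_identity` helper used throughout these tests is only partially visible in this diff; from its call sites, it presumably performs a round trip along the following lines (a sketch, not the repository's exact implementation):

```python
from typing import Type, TypeVar

import dcargs

T = TypeVar("T")


def _check_serialization_identity(cls: Type[T], instance: T) -> None:
    # Dump to YAML, load back, and check that the round trip preserves equality.
    assert dcargs.from_yaml(cls, dcargs.to_yaml(instance)) == instance
```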
