diff --git a/dcargs/extras/_serialization.py b/dcargs/extras/_serialization.py index c4ba73a4..c5c433a4 100644 --- a/dcargs/extras/_serialization.py +++ b/dcargs/extras/_serialization.py @@ -2,10 +2,11 @@ import dataclasses import enum +import functools from typing import IO, Any, Optional, Set, Type, TypeVar, Union import yaml -from typing_extensions import get_origin +from typing_extensions import get_args, get_origin from .. import _fields, _resolver @@ -48,15 +49,21 @@ def _get_contained_special_types_from_type( contained_dataclasses = {cls} def handle_type(typ) -> Set[Type]: + print(typ) + # Handle dataclasses. if _resolver.is_dataclass(typ) and typ not in parent_contained_dataclasses: return _get_contained_special_types_from_type( typ, _parent_contained_dataclasses=contained_dataclasses | parent_contained_dataclasses, ) + + # Handle enums. elif type(typ) is enum.EnumMeta: return {typ} - return set() + + # Handle Union, Annotated, List, etc. No-op when there are no args. + return functools.reduce(set.union, map(handle_type, get_args(typ)), set()) # Handle generics. for typ in type_from_typevar.values(): diff --git a/pyproject.toml b/pyproject.toml index 7dd66399..cda21fa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "dcargs" -version = "0.2.6" +version = "0.2.7" description = "Strongly typed, zero-effort CLI interfaces" authors = ["brentyi "] include = ["./dcargs/**/*"] diff --git a/tests/test_generics_and_serialization.py b/tests/test_generics_and_serialization.py index 0ac23f7c..c8d4e65f 100644 --- a/tests/test_generics_and_serialization.py +++ b/tests/test_generics_and_serialization.py @@ -5,6 +5,7 @@ from typing import Generic, List, Tuple, Type, TypeVar, Union import pytest +from typing_extensions import Annotated import dcargs @@ -346,3 +347,43 @@ def main(x: ActualParentClass[int] = ChildClass(5, 5)) -> ActualParentClass: return x assert dcargs.cli(main, args="--x.x 3".split(" ")) == ChildClass(3, 5, 3) + + +def test_pculbertson(): + # https://github.com/brentyi/dcargs/issues/7 + from typing import Union + + @dataclasses.dataclass + class TypeA: + data: int + + @dataclasses.dataclass + class TypeB: + data: int + + @dataclasses.dataclass + class Wrapper: + subclass: Union[TypeA, TypeB] = TypeA(1) + + wrapper1 = Wrapper() # Create Wrapper object. + wrapper2 = dcargs.extras.from_yaml( + Wrapper, dcargs.extras.to_yaml(wrapper1) + ) # Errors, no constructor for TypeA + + +def test_annotated(): + # https://github.com/brentyi/dcargs/issues/7 + from typing import Union + + @dataclasses.dataclass + class TypeA: + data: int + + @dataclasses.dataclass + class Wrapper: + subclass: Annotated[int, TypeA] = TypeA(1) + + wrapper1 = Wrapper() # Create Wrapper object. + wrapper2 = dcargs.extras.from_yaml( + Wrapper, dcargs.extras.to_yaml(wrapper1) + ) # Errors, no constructor for TypeA