diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..5c8b979 Binary files /dev/null and b/.DS_Store differ diff --git a/dataclasses_jsonschema/__init__.py b/dataclasses_jsonschema/__init__.py index c89fb5b..42b0c24 100644 --- a/dataclasses_jsonschema/__init__.py +++ b/dataclasses_jsonschema/__init__.py @@ -5,6 +5,7 @@ from decimal import Decimal from ipaddress import IPv4Address, IPv6Address from typing import Optional, Type, Union, Any, Dict, Tuple, List, Callable, TypeVar +from typing_extensions import get_args import re from datetime import datetime, date from dataclasses import fields, is_dataclass, Field, MISSING, dataclass, asdict @@ -438,7 +439,9 @@ def _decode_field(cls, field: str, field_type: Any, value: Any) -> Any: field_type_name = cls._get_field_type_name(field_type) # Note: Only literal types composed of primitive values are currently supported if type(value) in JSON_ENCODABLE_TYPES and (field_type in JSON_ENCODABLE_TYPES or is_literal(field_type)): - if is_literal(field_type): + if is_literal(field_type) and value not in get_args(field_type): + raise ValueError('Literal value is not in allowed set of values for type.') + elif is_literal(field_type): def decoder(_, __, val): return val else: def decoder(_, ft, val): return ft(val) diff --git a/tests/test_core.py b/tests/test_core.py index 9f390d7..d5745fd 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1052,3 +1052,26 @@ class Person(JsonSchemaMixin): data = Person(name="Joe", pet=Dog(breed="Pug")).to_dict() p = Person.from_dict(data) assert p.pet == Dog(breed="Pug") + + +def test_decode_literal(): + @dataclass + class Foo(JsonSchemaMixin): + common_field: int + name: Literal['Foo'] = 'Foo' + + @dataclass + class Bar(JsonSchemaMixin): + common_field: int + name: Literal['Bar'] = 'Bar' + other_field: Optional[int] = None + + @dataclass + class Baz(JsonSchemaMixin): + my_foo_bar: Union[Bar, Foo] + + decoded = Baz.from_dict(Baz(my_foo_bar=Foo(common_field=1)).to_dict()) + + # Even though Bar comes first in the Union for 'my_foor_bar', the literal check on 'name' should + # verify that this is a Foo and not assign it to a Bar + assert type(decoded.my_foo_bar) == Foo