Bugfix: deserialize Union of generics (#11)
Fixes `deserialize` to properly handle `Union` types that consist of generic types. Crucially, this covers the use of container types (e.g. `List`, `Dict`, `Sequence`, `Tuple`, `Mapping`) in `Union`s. New tests have been added to explicitly cover this case.

Version `0.3.2`
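
As a rough illustration of why unions of generics need special handling, the sketch below uses only the standard library's `typing.get_origin` and `typing.get_args` (Python 3.8+), which this commit adopts in place of direct `__args__` access. `FieldType` and `isinstance_raises` are hypothetical names made up for this example; they are not part of `pywise`.

```python
import collections.abc
from typing import Mapping, Sequence, Union, get_args, get_origin

# Hypothetical field type, for illustration only.
FieldType = Union[int, Sequence[int], Mapping[int, str]]

# get_args returns the union's members as a tuple.
members = get_args(FieldType)


def isinstance_raises(value, t) -> bool:
    """Return True if isinstance(value, t) raises TypeError."""
    try:
        isinstance(value, t)
    except TypeError:
        return True
    return False


for m in members:
    # get_origin is None for plain classes such as int, and the
    # unsubscripted runtime class for parameterized generics.
    print(m, "->", get_origin(m), "| isinstance raises:", isinstance_raises([1, 2], m))
```

Parameterized generics such as `Sequence[int]` cannot be passed to `isinstance`, so a deserializer has to inspect the origin and arguments of each union member rather than test the value against the member directly.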
malcolmgreaves authored Jan 20, 2022
1 parent 53259ef commit 5d0f77d
Showing 8 changed files with 160 additions and 55 deletions.
65 changes: 65 additions & 0 deletions .circleci/config.yml
@@ -0,0 +1,65 @@
# Use the latest 2.1 version of CircleCI pipeline process engine.
# See: https://circleci.com/docs/2.0/configuration-reference
version: 2.1

# # Orbs are reusable packages of CircleCI configuration that you may share across projects, enabling you to create encapsulated, parameterized commands, jobs, and executors that can be used across multiple projects.
# # See: https://circleci.com/docs/2.0/orb-intro/
# orbs:
# # The python orb contains a set of prepackaged CircleCI configuration you can use repeatedly in your configuration files
# # Orb commands and jobs help you with common scripting around a language/tool
# # so you dont have to copy and paste it everywhere.
# # See the orb documentation here: https://circleci.com/developer/orbs/orb/circleci/python
# python: circleci/[email protected]

# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
jobs:
build-and-test: # This is the name of the job, feel free to change it to better match what you're trying to do!
# These next lines defines a Docker executors: https://circleci.com/docs/2.0/executor-types/
# You can specify an image from Dockerhub or use one of the convenience images from CircleCI's Developer Hub
# A list of available CircleCI Docker convenience images are available here: https://circleci.com/developer/images/image/cimg/python
# The executor is the environment in which the steps below will be executed - below will use a python 3.8 container
# Change the version below to your required version of python
docker:
- image: cimg/python:3.8
# Checkout the code as the first step. This is a dedicated CircleCI step.
# The python orb's install-packages step will install the dependencies from a Pipfile via Pipenv by default.
# Here we're making sure we use just use the system-wide pip. By default it uses the project root's requirements.txt.
# Then run your tests!
# CircleCI will report the results back to your VCS provider.
steps:
- checkout
- run:
command: |
rm .python-version
pip install poetry
poetry install
poetry build
poetry run black --check
poetry run flake8 --max-line-length=100 --ignore=E501,W293,E303,W291,W503,E203,E731,E231,E721,E722,E741 .
poetry run mypy --ignore-missing-imports --follow-imports=silent --show-column-numbers --warn-unreachable .
poetry run pytest -v --cov core_utils
poetry run coverage html
poetry run coveralls
# - python/install-packages:
# pkg-manager: poetry
# # app-dir: ~/project/package-directory/ # If you're requirements.txt isn't in the root directory.
# # pip-dependency-file: test-requirements.txt # if you have a different name for your requirements file, maybe one that combines your runtime and test requirements.
# - run:
# name: Run tests
# # This assumes pytest is installed via the install-package step above
# command: |
# poetry run black --check
# poetry run flake8 --max-line-length=100 --ignore=E501,W293,E303,W291,W503,E203,E731,E231,E721,E722,E741 .
# poetry run mypy --ignore-missing-imports --follow-imports=silent --show-column-numbers --warn-unreachable .
# poetry run pytest -v --cov core_utils
# poetry run coverage html
# poetry run coveralls

# Invoke jobs via workflows
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
workflows:
ci: # This is the name of the workflow, feel free to change it to better match your workflow.
# Inside the workflow, you define the jobs you want to run.
jobs:
- build-and-test
17 changes: 0 additions & 17 deletions .travis.yml

This file was deleted.

2 changes: 1 addition & 1 deletion README.md
@@ -1,5 +1,5 @@
# `pywise`
[![PyPI version](https://badge.fury.io/py/pywise.svg)](https://badge.fury.io/py/pywise) [![Build Status](https://travis-ci.org/malcolmgreaves/pywise.svg?branch=main)](https://travis-ci.org/malcolmgreaves/pywise) [![Coverage Status](https://coveralls.io/repos/github/malcolmgreaves/pywise/badge.svg?branch=main)](https://coveralls.io/github/malcolmgreaves/pywise?branch=main)
[![PyPI version](https://badge.fury.io/py/pywise.svg)](https://badge.fury.io/py/pywise) [![CircleCI](https://circleci.com/gh/malcolmgreaves/pywise/tree/main.svg?style=svg)](https://circleci.com/gh/malcolmgreaves/pywise/tree/main) [![Coverage Status](https://coveralls.io/repos/github/malcolmgreaves/pywise/badge.svg?branch=main)](https://coveralls.io/github/malcolmgreaves/pywise?branch=main)

Contains functions that provide general utility and build upon the Python 3 standard library. It has no external dependencies.
- `serialization`: serialization & deserialization for `NamedTuple`-deriving & `@dataclass` decorated classes
13 changes: 8 additions & 5 deletions core_utils/common.py
@@ -1,5 +1,5 @@
from importlib import import_module
from typing import _GenericAlias, Any, Tuple, Optional, Type, TypeVar # type: ignore
from typing import _GenericAlias, Any, Tuple, Optional, Type, TypeVar, get_args # type: ignore


def type_name(t: type, keep_main: bool = True) -> str:
@@ -19,7 +19,7 @@ def type_name(t: type, keep_main: bool = True) -> str:

if str(t).startswith("typing.Union"):
try:
args = t.__args__ # type: ignore
args = get_args(t)
if len(args) == 2 and args[1] == type(None): # noqa: E721
# an Optional type is equivalent to Union[T, None]
return f"typing.Optional[{type_name(args[0])}]"
@@ -36,9 +36,12 @@ def type_name(t: type, keep_main: bool = True) -> str:
full_name = f"{mod}.{t.__name__}"
try:
# generic parameters ?
args = tuple(map(type_name, t.__args__)) # type: ignore
a = ", ".join(args)
complete_type_name = f"{full_name}[{a}]"
args = tuple(map(type_name, get_args(t))) # type: ignore
if len(args) > 0:
a = ", ".join(args)
complete_type_name: str = f"{full_name}[{a}]"
else:
complete_type_name = full_name
except Exception:
complete_type_name = full_name

19 changes: 11 additions & 8 deletions core_utils/schema.py
@@ -1,4 +1,4 @@
from typing import Type, Iterable, Union, Any, Mapping, Sequence
from typing import Type, Iterable, Union, Any, Mapping, Sequence, get_args, cast
from dataclasses import is_dataclass

from core_utils.common import type_name, checkable_type
@@ -51,22 +51,25 @@ def _dict_type(t: type):
checkable_t: Type = checkable_type(t)
if issubclass(checkable_t, Mapping):
try:
key_t: type = t.__args__[0] # type: ignore
val_t: type = t.__args__[1] # type: ignore
_args = get_args(t)
key_t: type = cast(type, _args[0])
val_t: type = cast(type, _args[1])
except Exception as e:
raise TypeError(
f"Could not extract key & value types from dict type: '{t}'"
) from e
k = _dict_type(key_t)
v = _dict_type(val_t)
return {k: v}
else:
k = _dict_type(key_t)
v = _dict_type(val_t)
return {k: v}

elif issubclass(checkable_t, Iterable) and t != str:
try:
inner_t: type = t.__args__[0] # type: ignore
inner_t: type = cast(type, get_args(t)[0])
except Exception as e:
raise TypeError(
f"Could not extract inner type from iterable type: '{t}'"
) from e
return [_dict_type(inner_t)]
else:
return [_dict_type(inner_t)]
return tn
56 changes: 33 additions & 23 deletions core_utils/serialization.py
@@ -11,6 +11,10 @@
Optional,
Iterator,
Sequence,
get_origin,
get_args,
Union,
cast,
)
from dataclasses import dataclass, is_dataclass, Field

@@ -120,7 +124,6 @@ def deserialize(
NOTE: If using :param:`custom` for generic types, you *must* have unique instances for each possible
type parametrization.
"""

if custom is not None and type_value in custom:
return custom[type_value](value)

@@ -145,30 +148,42 @@ def deserialize(
if value is None:
return None
else:
return deserialize(type_value.__args__[0], value, custom)
inner_type = cast(type, get_args(type_value)[0])
return deserialize(inner_type, value, custom)

# NOTE: Need to have type_value instead of checking_type_value here !
elif _is_union(type_value):
for possible_type in type_value.__args__:
# try to deserialize the value using one of its
# possible types
for _p in get_args(type_value):
possible_type = cast(type, _p)
# determine if the value could be deserialized into one
# of the union's listed types
# try:
# # for "concrete" types
# ok_to_deserialize_into: bool = isinstance(value, possible_type)
# except Exception:
# # for generics, e.g. collection types
# ok_to_deserialize_into = isinstance(value, get_origin(possible_type))
# if ok_to_deserialize_into:
# return deserialize(possible_type, value, custom)
try:
return deserialize(possible_type, value, custom)
except Exception:
pass
continue
raise FieldDeserializeFail(
field_name="", expected_type=type_value, actual_value=value
)

elif issubclass(checking_type_value, Mapping):
k_type, v_type = type_value.__args__ # type: ignore
_args = get_args(type_value)
k_type = cast(type, _args[0])
v_type = cast(type, _args[1])
return {
deserialize(k_type, k, custom): deserialize(v_type, v, custom)
for k, v in value.items()
}

elif issubclass(checking_type_value, Tuple) and checking_type_value != str: # type: ignore
tuple_type_args = type_value.__args__
tuple_type_args = get_args(type_value)
converted = map(
lambda type_val_pair: deserialize(
type_val_pair[0], type_val_pair[1], custom
@@ -178,7 +193,10 @@ def deserialize(
return tuple(converted)

elif issubclass(checking_type_value, Iterable) and checking_type_value != str:
(i_type,) = type_value.__args__ # type: ignore
# special case: fail-fast on trying to treat a dict as list-like
if isinstance(value, dict):
raise FieldDeserializeFail("", type_value, value)
i_type = cast(type, get_args(type_value)[0])
converted = map(lambda x: deserialize(i_type, x, custom), value)
if issubclass(checking_type_value, Set):
return set(converted)
@@ -365,17 +383,17 @@ def _align_generic_concrete(
then the generics will be handled appropriately.
"""
try:
origin = data_type_with_generics.__origin__
origin = data_type_with_generics.__origin__ # type: ignore
if issubclass(origin, Sequence):
generics = [TypeVar("T")] # type: ignore
values = data_type_with_generics.__args__
values = get_args(data_type_with_generics)
elif issubclass(origin, Mapping):
generics = [TypeVar("KT"), TypeVar("VT_co")] # type: ignore
values = data_type_with_generics.__args__
values = get_args(data_type_with_generics)
else:
# should be a dataclass
generics = origin.__parameters__ # type: ignore
values = data_type_with_generics.__args__ # type: ignore
values = get_args(data_type_with_generics) # type: ignore
for g, v in zip(generics, values):
yield g, v
except AttributeError as e:
@@ -547,7 +565,7 @@ def _is_optional(t: type) -> bool:
"""Evaluates to true iff the input is a type that is equivalent to an `Optional`.
"""
try:
type_args = t.__args__ # type: ignore
type_args = get_args(t)
only_one_none_type = (
len(list(filter(lambda x: x == type(None), type_args))) == 1 # type: ignore
)
@@ -559,12 +577,4 @@ def _is_optional(t: type) -> bool:
def _is_union(t: type) -> bool:
"""Evaluates to true iff the input is a union (not an Optional) type.
"""
try:
type_args = t.__args__ # type: ignore
return (
not _is_optional(t)
and all(map(lambda x: isinstance(x, type), type_args))
and type_name(t).startswith("typing.Union")
)
except Exception:
return False
return get_origin(t) is Union and not _is_optional(t)
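
The `_is_union` hunk above swaps an `isinstance(x, type)` check over the union's arguments for `get_origin(t) is Union`. The standalone sketch below shows why that matters for generics. Note that `is_optional_sketch` is an assumption: the full body of `_is_optional` is not visible in this diff, so the version here simply treats a two-member union containing `NoneType` as an `Optional`. Both function names are hypothetical.

```python
from typing import Dict, List, Optional, Union, get_args, get_origin


def is_optional_sketch(t) -> bool:
    # Assumption: an Optional is a Union of exactly two members, one being NoneType.
    args = get_args(t)
    return get_origin(t) is Union and len(args) == 2 and type(None) in args


def is_union_sketch(t) -> bool:
    # Mirrors the one-line replacement in the hunk above.
    return get_origin(t) is Union and not is_optional_sketch(t)


# A parameterized generic is not an instance of `type`, so a check that
# requires every union member to be a `type` rejects unions of generics.
print(isinstance(List[int], type))
print(is_union_sketch(Union[List[int], Dict[str, int]]))
print(is_union_sketch(Optional[int]))
```

Checking the origin alone avoids making any claim about what the union's members are, which is what lets container types through.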
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pywise"
version = "0.3.1"
version = "0.3.2"
description = "Robust serialization support for NamedTuple & @dataclass data types."
authors = ["Malcolm Greaves <[email protected]>"]
homepage = "https://github.com/malcolmgreaves/pywise"
41 changes: 41 additions & 0 deletions tests/test_deserialize_unions_with_collections.py
@@ -0,0 +1,41 @@
import pytest

from core_utils.serialization import serialize, deserialize
from dataclasses import dataclass
from typing import Union, Sequence, Mapping


@dataclass(frozen=True)
class A:
a: int


@dataclass(frozen=True)
class B:
b: str


@dataclass(frozen=True)
class C:
c: Union[A, B, Sequence[A], Sequence[B], Mapping[int, str]]


@pytest.mark.parametrize(
"c_input",
[
# "simple" dataclass cases
C(A(1)),
C(B("hello world")),
# list cases
C([A(1)]),
C([B("hello world")]),
C([A(1), A(2), A(4)]),
C([B("a"), B("b"), B("c")]),
C([]),
# dict cases
C({0: "hello", 1: "world", 2: "how", 3: "are", 4: "you"}),
C(dict()),
],
)
def test_deserialize_dataclass_with_union_of_collections(c_input: C) -> None:
assert deserialize(C, serialize(c_input)) == c_input
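
The behaviour these tests exercise can be approximated without the library. The toy function below is only a sketch of the general strategy visible in the `deserialize` diff (try each union member in order, and fail fast when a dict is offered to a list-like type); `toy_deserialize` and `FieldType` are hypothetical names, and the real `deserialize` also handles dataclasses, `NamedTuple`s, tuples, sets, and custom deserializers that are omitted here.

```python
import collections.abc as abc
from typing import Any, Mapping, Sequence, Union, get_args, get_origin


def toy_deserialize(t: Any, value: Any) -> Any:
    origin = get_origin(t)
    if origin is Union:
        # Try each member in declaration order; the first that succeeds wins.
        for member in get_args(t):
            try:
                return toy_deserialize(member, value)
            except (TypeError, ValueError):
                continue
        raise TypeError(f"no member of {t} accepts {value!r}")
    if origin is not None and issubclass(origin, abc.Mapping):
        if not isinstance(value, dict):
            raise TypeError(f"expected a dict for {t}, got {value!r}")
        k_t, v_t = get_args(t)
        return {toy_deserialize(k_t, k): toy_deserialize(v_t, v) for k, v in value.items()}
    if origin is not None and issubclass(origin, abc.Sequence):
        # Fail fast: iterating a dict yields its keys, which could
        # otherwise be mistaken for a valid list of elements.
        if isinstance(value, (dict, str)):
            raise TypeError(f"expected a list for {t}, got {value!r}")
        (item_t,) = get_args(t)
        return [toy_deserialize(item_t, x) for x in value]
    # Plain (non-generic) type: accept only values that are already instances.
    if not isinstance(value, t):
        raise TypeError(f"expected {t}, got {value!r}")
    return value


# Hypothetical field type, for illustration only.
FieldType = Union[int, Sequence[int], Mapping[int, str]]
print(toy_deserialize(FieldType, {0: "hello", 1: "world"}))
```

Without the dict check in the sequence branch, the dict above would be consumed by `Sequence[int]` as a list of its keys before the `Mapping[int, str]` member was ever tried, which is presumably the failure mode the fail-fast branch in the commit guards against.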
