-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from klauer/enh_deps_and_serialization
ENH: dependencies, first pass at serialization, and more
- Loading branch information
Showing
11 changed files
with
1,735 additions
and
1,396 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
""" | ||
Serialization helpers for apischema, an optional dependency. | ||
""" | ||
# Largely based on issue discussions regarding tagged unions. | ||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from collections.abc import Callable, Iterator | ||
from types import new_class | ||
from typing import Any, Dict, Generic, List, Tuple, TypeVar, get_type_hints | ||
|
||
import lark | ||
from apischema import deserializer, serializer, type_name | ||
from apischema.conversions import Conversion | ||
from apischema.metadata import conversion | ||
from apischema.objects import object_deserialization | ||
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged | ||
from apischema.typing import get_origin | ||
from apischema.utils import to_pascal_case | ||
|
||
_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list) | ||
Func = TypeVar("Func", bound=Callable) | ||
|
||
|
||
def alternative_constructor(func: Func) -> Func: | ||
"""Alternative constructor for a given type.""" | ||
return_type = get_type_hints(func)["return"] | ||
_alternative_constructors[get_origin(return_type) or return_type].append(func) | ||
return func | ||
|
||
|
||
def get_all_subclasses(cls: type) -> Iterator[type]: | ||
"""Recursive implementation of type.__subclasses__""" | ||
for sub_cls in cls.__subclasses__(): | ||
yield sub_cls | ||
yield from get_all_subclasses(sub_cls) | ||
|
||
|
||
Cls = TypeVar("Cls", bound=type) | ||
|
||
|
||
def _get_generic_name_factory(cls: type, *args: type): | ||
def _capitalized(name: str) -> str: | ||
return name[0].upper() + name[1:] | ||
|
||
return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args))) | ||
|
||
|
||
generic_name = type_name(_get_generic_name_factory) | ||
|
||
|
||
def as_tagged_union(cls: Cls) -> Cls: | ||
""" | ||
Tagged union decorator, to be used on base class. | ||
Supports generics as well, with names generated by way of | ||
`_get_generic_name_factory`. | ||
""" | ||
params = tuple(getattr(cls, "__parameters__", ())) | ||
tagged_union_bases: Tuple[type, ...] = (TaggedUnion,) | ||
|
||
# Generic handling is here: | ||
if params: | ||
tagged_union_bases = (TaggedUnion, Generic[params]) | ||
generic_name(cls) | ||
prev_init_subclass = getattr(cls, "__init_subclass__", None) | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
if prev_init_subclass is not None: | ||
prev_init_subclass(**kwargs) | ||
generic_name(cls) | ||
|
||
cls.__init_subclass__ = classmethod(__init_subclass__) | ||
|
||
def with_params(cls: type) -> Any: | ||
"""Specify type of Generic if set.""" | ||
return cls[params] if params else cls | ||
|
||
def serialization() -> Conversion: | ||
""" | ||
Define the serializer Conversion for the tagged union. | ||
source is the base ``cls`` (or ``cls[T]``). | ||
target is the new tagged union class ``TaggedUnion`` which gets the | ||
dictionary {cls.__name__: obj} as its arguments. | ||
""" | ||
annotations = { | ||
# Assume that subclasses have same generic parameters than cls | ||
sub.__name__: Tagged[with_params(sub)] | ||
for sub in get_all_subclasses(cls) | ||
} | ||
namespace = {"__annotations__": annotations} | ||
tagged_union = new_class( | ||
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) | ||
) | ||
return Conversion( | ||
lambda obj: tagged_union(**{obj.__class__.__name__: obj}), | ||
source=with_params(cls), | ||
target=with_params(tagged_union), | ||
# Conversion must not be inherited because it would lead to | ||
# infinite recursion otherwise | ||
inherited=False, | ||
) | ||
|
||
def deserialization() -> Conversion: | ||
""" | ||
Define the deserializer Conversion for the tagged union. | ||
Allows for alternative standalone constructors as per the apischema | ||
example. | ||
""" | ||
annotations: dict[str, Any] = {} | ||
namespace: dict[str, Any] = {"__annotations__": annotations} | ||
for sub in get_all_subclasses(cls): | ||
annotations[sub.__name__] = Tagged[with_params(sub)] | ||
for constructor in _alternative_constructors.get(sub, ()): | ||
# Build the alias of the field | ||
alias = to_pascal_case(constructor.__name__) | ||
# object_deserialization uses get_type_hints, but the constructor | ||
# return type is stringified and the class not defined yet, | ||
# so it must be assigned manually | ||
constructor.__annotations__["return"] = with_params(sub) | ||
# Use object_deserialization to wrap constructor as deserializer | ||
deserialization = object_deserialization(constructor, generic_name) | ||
# Add constructor tagged field with its conversion | ||
annotations[alias] = Tagged[with_params(sub)] | ||
namespace[alias] = Tagged(conversion(deserialization=deserialization)) | ||
# Create the deserialization tagged union class | ||
tagged_union = new_class( | ||
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) | ||
) | ||
return Conversion( | ||
lambda obj: get_tagged(obj)[1], | ||
source=with_params(tagged_union), | ||
target=with_params(cls), | ||
) | ||
|
||
deserializer(lazy=deserialization, target=cls) | ||
serializer(lazy=serialization, source=cls) | ||
return cls | ||
|
||
|
||
@serializer | ||
def token_serializer(token: lark.Token) -> List[str]: | ||
return [token.type, token.value] | ||
|
||
|
||
@deserializer | ||
def token_deserializer(parts: List[str]) -> lark.Token: | ||
return lark.Token(*parts) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
import os | ||
|
||
BLARK_TWINCAT_ROOT = os.environ.get("BLARK_TWINCAT_ROOT", ".") |
Oops, something went wrong.