Skip to content

Commit

Permalink
Pr/aleksander pawlak/170 (#186)
Browse files Browse the repository at this point in the history
* Fix forward references

* Fix forward references

* Add comment; Remove unused imports

* Remove unused imports

Fix flake8

Fix deprecated usage of 'self'

* Add usage inspect.stack to get locals from scope of class declaration

* Fix proper passing of class definition frame to class_schema of nested fields

* Change lambda to partial

* Add params im test_class_schema

* Fix type annotation in dataclass decorator. Add better doc

* [README] Add documentation about Union types (#184)

* [README] Update README.md to address issue #183

* Update README.md

* fix typo in README.md

Co-authored-by: Ophir LOJKINE <[email protected]>

* Fix use of typing.NewType in Python 3.10 (#180)

* Fix use of typing.NewType in Python 3.10

https://docs.python.org/3.10/library/typing.html#typing.NewType

"Changed in version 3.10: NewType is now a class rather than a function."

* Add Python 3.10 to GitHub workflow test matrix

Note the use of strings: otherwise YAML thinks we want to test Python 3.1

* Pin types-dataclasses version for Python 3.6

It looks broken for Python 3.6 due to python/typeshed@a40d79a

* fix ci (#185)

* Update dev dependencies
* mypy has gotten smarter, and it tries to do an automatic cast that breaks our code

* drop support for pre-commit hooks on 3.6

* downgrade types-dataclass

Co-authored-by: Adrian Dankiv <[email protected]>
Co-authored-by: Adrian Dankiv <[email protected]>
Co-authored-by: Aleksander Pawlak <[email protected]>
Co-authored-by: John Bodley <[email protected]>
Co-authored-by: Nicolas Noirbent <[email protected]>
  • Loading branch information
6 people authored Apr 21, 2022
1 parent 228f0e8 commit fac0b6d
Show file tree
Hide file tree
Showing 4 changed files with 311 additions and 24 deletions.
109 changes: 85 additions & 24 deletions marshmallow_dataclass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ class User:
import collections.abc
import dataclasses
import inspect
import types
import warnings
from enum import EnumMeta
from functools import lru_cache
from functools import lru_cache, partial
from typing import (
Any,
Callable,
Expand All @@ -53,6 +54,7 @@ class User:
TypeVar,
Union,
cast,
get_type_hints,
overload,
Sequence,
FrozenSet,
Expand All @@ -61,6 +63,9 @@ class User:
import marshmallow
import typing_inspect

from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute


__all__ = ["dataclass", "add_schema", "class_schema", "field_for_schema", "NewType"]

NoneType = type(None)
Expand All @@ -83,6 +88,7 @@ def dataclass(
unsafe_hash: bool = False,
frozen: bool = False,
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
) -> Type[_U]:
...

Expand All @@ -96,6 +102,7 @@ def dataclass(
unsafe_hash: bool = False,
frozen: bool = False,
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
) -> Callable[[Type[_U]], Type[_U]]:
...

Expand All @@ -112,12 +119,15 @@ def dataclass(
unsafe_hash: bool = False,
frozen: bool = False,
base_schema: Optional[Type[marshmallow.Schema]] = None,
cls_frame: Optional[types.FrameType] = None,
) -> Union[Type[_U], Callable[[Type[_U]], Type[_U]]]:
"""
This decorator does the same as dataclasses.dataclass, but also applies :func:`add_schema`.
It adds a `.Schema` attribute to the class object
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
:param cls_frame: frame of cls definition, used to obtain locals with other classes definitions.
If None is passed the caller frame will be treated as cls_frame
>>> @dataclass
... class Artist:
Expand All @@ -140,9 +150,10 @@ def dataclass(
dc = dataclasses.dataclass( # type: ignore
_cls, repr=repr, eq=eq, order=order, unsafe_hash=unsafe_hash, frozen=frozen
)
cls_frame = cls_frame or inspect.stack()[1][0]
if _cls is None:
return lambda cls: add_schema(dc(cls), base_schema)
return add_schema(dc, base_schema)
return lambda cls: add_schema(dc(cls), base_schema, cls_frame=cls_frame)
return add_schema(dc, base_schema, cls_frame=cls_frame)


@overload
Expand All @@ -159,18 +170,21 @@ def add_schema(

@overload
def add_schema(
_cls: Type[_U], base_schema: Type[marshmallow.Schema] = None
_cls: Type[_U],
base_schema: Type[marshmallow.Schema] = None,
cls_frame: types.FrameType = None,
) -> Type[_U]:
...


def add_schema(_cls=None, base_schema=None):
def add_schema(_cls=None, base_schema=None, cls_frame=None):
"""
This decorator adds a marshmallow schema as the 'Schema' attribute in a dataclass.
It uses :func:`class_schema` internally.
:param type _cls: The dataclass to which a Schema should be added
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
:param cls_frame: frame of cls definition
>>> class BaseSchema(marshmallow.Schema):
... def on_bind_field(self, field_name, field_obj):
Expand All @@ -187,20 +201,27 @@ def add_schema(_cls=None, base_schema=None):

def decorator(clazz: Type[_U]) -> Type[_U]:
# noinspection PyTypeHints
clazz.Schema = class_schema(clazz, base_schema) # type: ignore
clazz.Schema = lazy_class_attribute( # type: ignore
partial(class_schema, clazz, base_schema, cls_frame),
"Schema",
clazz.__name__,
)
return clazz

return decorator(_cls) if _cls else decorator


def class_schema(
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
clazz_frame: types.FrameType = None,
) -> Type[marshmallow.Schema]:
"""
Convert a class to a marshmallow schema
:param clazz: A python class (may be a dataclass)
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
:param clazz_frame: frame of cls definition
:return: A marshmallow Schema corresponding to the dataclass
.. note::
Expand Down Expand Up @@ -315,12 +336,14 @@ def class_schema(
"""
if not dataclasses.is_dataclass(clazz):
clazz = dataclasses.dataclass(clazz)
return _internal_class_schema(clazz, base_schema)
return _internal_class_schema(clazz, base_schema, clazz_frame)


@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
def _internal_class_schema(
clazz: type, base_schema: Optional[Type[marshmallow.Schema]] = None
clazz: type,
base_schema: Optional[Type[marshmallow.Schema]] = None,
clazz_frame: types.FrameType = None,
) -> Type[marshmallow.Schema]:
try:
# noinspection PyDataclass
Expand All @@ -339,7 +362,7 @@ def _internal_class_schema(
"****** WARNING ******"
)
created_dataclass: type = dataclasses.dataclass(clazz)
return _internal_class_schema(created_dataclass, base_schema)
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
except Exception:
raise TypeError(
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
Expand All @@ -351,12 +374,20 @@ def _internal_class_schema(
for k, v in inspect.getmembers(clazz)
if hasattr(v, "__marshmallow_hook__") or k in MEMBERS_WHITELIST
}

# Update the schema members to contain marshmallow fields instead of dataclass fields
type_hints = get_type_hints(
clazz, localns=clazz_frame.f_locals if clazz_frame else None
)
attributes.update(
(
field.name,
field_for_schema(
field.type, _get_field_default(field), field.metadata, base_schema
type_hints[field.name],
_get_field_default(field),
field.metadata,
base_schema,
clazz_frame,
),
)
for field in fields
Expand All @@ -381,6 +412,7 @@ def _field_by_supertype(
newtype_supertype: Type,
metadata: dict,
base_schema: Optional[Type[marshmallow.Schema]],
typ_frame: Optional[types.FrameType],
) -> marshmallow.fields.Field:
"""
Return a new field for fields based on a super field. (Usually spawned from NewType)
Expand Down Expand Up @@ -411,6 +443,7 @@ def _field_by_supertype(
metadata=metadata,
default=default,
base_schema=base_schema,
typ_frame=typ_frame,
)


Expand All @@ -432,7 +465,10 @@ def _generic_type_add_any(typ: type) -> type:


def _field_for_generic_type(
typ: type, base_schema: Optional[Type[marshmallow.Schema]], **metadata: Any
typ: type,
base_schema: Optional[Type[marshmallow.Schema]],
typ_frame: Optional[types.FrameType],
**metadata: Any,
) -> Optional[marshmallow.fields.Field]:
"""
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
Expand All @@ -444,7 +480,9 @@ def _field_for_generic_type(
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}

if origin in (list, List):
child_type = field_for_schema(arguments[0], base_schema=base_schema)
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
list_type = cast(
Type[marshmallow.fields.List],
type_mapping.get(List, marshmallow.fields.List),
Expand All @@ -453,25 +491,32 @@ def _field_for_generic_type(
if origin in (collections.abc.Sequence, Sequence):
from . import collection_field

child_type = field_for_schema(arguments[0], base_schema=base_schema)
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
if origin in (set, Set):
from . import collection_field

child_type = field_for_schema(arguments[0], base_schema=base_schema)
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
return collection_field.Set(
cls_or_instance=child_type, frozen=False, **metadata
)
if origin in (frozenset, FrozenSet):
from . import collection_field

child_type = field_for_schema(arguments[0], base_schema=base_schema)
child_type = field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
)
return collection_field.Set(
cls_or_instance=child_type, frozen=True, **metadata
)
if origin in (tuple, Tuple):
children = tuple(
field_for_schema(arg, base_schema=base_schema) for arg in arguments
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
for arg in arguments
)
tuple_type = cast(
Type[marshmallow.fields.Tuple],
Expand All @@ -483,8 +528,12 @@ def _field_for_generic_type(
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
return dict_type(
keys=field_for_schema(arguments[0], base_schema=base_schema),
values=field_for_schema(arguments[1], base_schema=base_schema),
keys=field_for_schema(
arguments[0], base_schema=base_schema, typ_frame=typ_frame
),
values=field_for_schema(
arguments[1], base_schema=base_schema, typ_frame=typ_frame
),
**metadata,
)
elif typing_inspect.is_union_type(typ):
Expand All @@ -497,7 +546,10 @@ def _field_for_generic_type(
subtypes = [t for t in arguments if t is not NoneType] # type: ignore
if len(subtypes) == 1:
return field_for_schema(
subtypes[0], metadata=metadata, base_schema=base_schema
subtypes[0],
metadata=metadata,
base_schema=base_schema,
typ_frame=typ_frame,
)
from . import union_field

Expand All @@ -506,7 +558,10 @@ def _field_for_generic_type(
(
subtyp,
field_for_schema(
subtyp, metadata={"required": True}, base_schema=base_schema
subtyp,
metadata={"required": True},
base_schema=base_schema,
typ_frame=typ_frame,
),
)
for subtyp in subtypes
Expand All @@ -521,6 +576,7 @@ def field_for_schema(
default=marshmallow.missing,
metadata: Mapping[str, Any] = None,
base_schema: Optional[Type[marshmallow.Schema]] = None,
typ_frame: Optional[types.FrameType] = None,
) -> marshmallow.fields.Field:
"""
Get a marshmallow Field corresponding to the given python type.
Expand All @@ -530,6 +586,7 @@ def field_for_schema(
:param default: value to use for (de)serialization when the field is missing
:param metadata: Additional parameters to pass to the marshmallow field constructor
:param base_schema: marshmallow schema used as a base class when deriving dataclass schema
:param typ_frame: frame of type definition
>>> int_field = field_for_schema(int, default=9, metadata=dict(required=True))
>>> int_field.__class__
Expand Down Expand Up @@ -588,10 +645,10 @@ def field_for_schema(
subtyp = type(default)
else:
subtyp = Any
return field_for_schema(subtyp, default, metadata, base_schema)
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)

# Generic types
generic_field = _field_for_generic_type(typ, base_schema, **metadata)
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
if generic_field:
return generic_field

Expand All @@ -605,6 +662,7 @@ def field_for_schema(
newtype_supertype=newtype_supertype,
metadata=metadata,
base_schema=base_schema,
typ_frame=typ_frame,
)

# enumerations
Expand All @@ -614,12 +672,15 @@ def field_for_schema(
return marshmallow_enum.EnumField(typ, **metadata)

# Nested marshmallow dataclass
# it would be just a class name instead of actual schema util the schema is not ready yet
nested_schema = getattr(typ, "Schema", None)

# Nested dataclasses
forward_reference = getattr(typ, "__forward_arg__", None)
nested = (
nested_schema or forward_reference or _internal_class_schema(typ, base_schema)
nested_schema
or forward_reference
or _internal_class_schema(typ, base_schema, typ_frame)
)

return marshmallow.fields.Nested(nested, **metadata)
Expand Down
42 changes: 42 additions & 0 deletions marshmallow_dataclass/lazy_class_attribute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from typing import Any, Callable


__all__ = ("lazy_class_attribute",)


class LazyClassAttribute:
"""Descriptor decorator implementing a class-level, read-only
property, which caches its results on the class(es) on which it
operates.
"""

__slots__ = ("func", "name", "called", "forward_value")

def __init__(
self, func: Callable[..., Any], name: str = None, forward_value: Any = None
):
self.func = func
self.name = name
self.called = False
self.forward_value = forward_value

def __get__(self, instance, cls=None):
if not cls:
cls = type(instance)

# avoid recursion
if self.called:
return self.forward_value

self.called = True

setattr(cls, self.name, self.func())

# "getattr" is used to handle bounded methods
return getattr(cls, self.name)

def __set_name__(self, owner, name):
self.name = self.name or name


lazy_class_attribute = LazyClassAttribute
Loading

0 comments on commit fac0b6d

Please sign in to comment.