Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify Hooks #2279

Merged
merged 2 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion performance/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def run_timeit(quotes, iterations, repeat, profile=False):
profile.disable()
profile.dump_stats("marshmallow.pprof")

usec = best * 1e6 / iterations
usec = best * 1e6 / iterations / len(quotes)
return usec


Expand Down
24 changes: 13 additions & 11 deletions src/marshmallow/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def validate_age(self, data, **kwargs):
from __future__ import annotations

import functools
from collections import defaultdict
from typing import Any, Callable, cast

PRE_DUMP = "pre_dump"
Expand All @@ -79,7 +80,7 @@ def validate_age(self, data, **kwargs):


class MarshmallowHook:
__marshmallow_hook__: dict[tuple[str, bool] | str, Any] | None = None
__marshmallow_hook__: dict[str, list[tuple[bool, Any]]] | None = None


def validates(field_name: str) -> Callable[..., Any]:
Expand Down Expand Up @@ -117,7 +118,8 @@ def validates_schema(
"""
return set_hook(
fn,
(VALIDATES_SCHEMA, pass_many),
VALIDATES_SCHEMA,
many=pass_many,
pass_original=pass_original,
skip_on_field_errors=skip_on_field_errors,
)
Expand All @@ -136,7 +138,7 @@ def pre_dump(
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
return set_hook(fn, (PRE_DUMP, pass_many))
return set_hook(fn, PRE_DUMP, many=pass_many)


def post_dump(
Expand All @@ -157,7 +159,7 @@ def post_dump(
.. versionchanged:: 3.0.0
``many`` is always passed as a keyword arguments to the decorated method.
"""
return set_hook(fn, (POST_DUMP, pass_many), pass_original=pass_original)
return set_hook(fn, POST_DUMP, many=pass_many, pass_original=pass_original)


def pre_load(
Expand All @@ -174,7 +176,7 @@ def pre_load(
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
return set_hook(fn, (PRE_LOAD, pass_many))
return set_hook(fn, PRE_LOAD, many=pass_many)


def post_load(
Expand All @@ -196,11 +198,11 @@ def post_load(
``partial`` and ``many`` are always passed as keyword arguments to
the decorated method.
"""
return set_hook(fn, (POST_LOAD, pass_many), pass_original=pass_original)
return set_hook(fn, POST_LOAD, many=pass_many, pass_original=pass_original)


def set_hook(
fn: Callable[..., Any] | None, key: tuple[str, bool] | str, **kwargs: Any
fn: Callable[..., Any] | None, tag: str, many: bool = False, **kwargs: Any
) -> Callable[..., Any]:
"""Mark decorated function as a hook to be picked up later.
You should not need to use this method directly.
Expand All @@ -214,18 +216,18 @@ def set_hook(
"""
# Allow using this as either a decorator or a decorator factory.
if fn is None:
return functools.partial(set_hook, key=key, **kwargs)
return functools.partial(set_hook, tag=tag, many=many, **kwargs)

# Set a __marshmallow_hook__ attribute instead of wrapping in some class,
# because I still want this to end up as a normal (unbound) method.
function = cast(MarshmallowHook, fn)
try:
hook_config = function.__marshmallow_hook__
except AttributeError:
function.__marshmallow_hook__ = hook_config = {}
function.__marshmallow_hook__ = hook_config = defaultdict(list)
# Also save the kwargs for the tagged function on
# __marshmallow_hook__, keyed by (<tag>, <pass_many>)
# __marshmallow_hook__, keyed by <tag>
if hook_config is not None:
hook_config[key] = kwargs
hook_config[tag].append((many, kwargs))

return fn
44 changes: 20 additions & 24 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ def __init__(cls, name, bases, attrs):
class_registry.register(name, cls)
cls._hooks = cls.resolve_hooks()

def resolve_hooks(cls) -> dict[types.Tag, list[str]]:
def resolve_hooks(cls) -> dict[str, list[tuple[str, bool, dict]]]:
"""Add in the decorated processors

By doing this after constructing the class, we let standard inheritance
do all the hard work.
"""
mro = inspect.getmro(cls)

hooks = defaultdict(list) # type: typing.Dict[types.Tag, typing.List[str]]
hooks = defaultdict(list) # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]]

for attr_name in dir(cls):
# Need to look up the actual descriptor, not whatever might be
Expand All @@ -176,14 +176,16 @@ def resolve_hooks(cls) -> dict[types.Tag, list[str]]:
continue

try:
hook_config = attr.__marshmallow_hook__
hook_config = attr.__marshmallow_hook__ # type: typing.Dict[str, typing.List[typing.Tuple[bool, dict]]]
except AttributeError:
pass
else:
for key in hook_config.keys():
for tag, config in hook_config.items():
# Use name here so we can get the bound method later, in
# case the processor was a descriptor or something.
hooks[key].append(attr_name)
hooks[tag].extend(
(attr_name, many, kwargs) for many, kwargs in config
)

return hooks

Expand Down Expand Up @@ -319,7 +321,7 @@ class AlbumSchema(Schema):
# These get set by SchemaMeta
opts = None # type: SchemaOpts
_declared_fields = {} # type: typing.Dict[str, ma_fields.Field]
_hooks = {} # type: typing.Dict[types.Tag, typing.List[str]]
_hooks = {} # type: typing.Dict[str, typing.List[typing.Tuple[str, bool, dict]]]

class Meta:
"""Options object for a Schema.
Expand Down Expand Up @@ -539,7 +541,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None):
Validation no longer occurs upon serialization.
"""
many = self.many if many is None else bool(many)
if self._has_processors(PRE_DUMP):
if self._hooks[PRE_DUMP]:
processed_obj = self._invoke_dump_processors(
PRE_DUMP, obj, many=many, original_data=obj
)
Expand All @@ -548,7 +550,7 @@ def dump(self, obj: typing.Any, *, many: bool | None = None):

result = self._serialize(processed_obj, many=many)

if self._has_processors(POST_DUMP):
if self._hooks[POST_DUMP]:
result = self._invoke_dump_processors(
POST_DUMP, result, many=many, original_data=obj
)
Expand Down Expand Up @@ -846,7 +848,7 @@ def _do_load(
if partial is None:
partial = self.partial
# Run preprocessors
if self._has_processors(PRE_LOAD):
if self._hooks[PRE_LOAD]:
try:
processed_data = self._invoke_load_processors(
PRE_LOAD, data, many=many, original_data=data, partial=partial
Expand All @@ -870,7 +872,7 @@ def _do_load(
error_store=error_store, data=result, many=many
)
# Run schema-level validation
if self._has_processors(VALIDATES_SCHEMA):
if self._hooks[VALIDATES_SCHEMA]:
field_errors = bool(error_store.errors)
self._invoke_schema_validators(
error_store=error_store,
Expand All @@ -892,7 +894,7 @@ def _do_load(
)
errors = error_store.errors
# Run post processors
if not errors and postprocess and self._has_processors(POST_LOAD):
if not errors and postprocess and self._hooks[POST_LOAD]:
try:
result = self._invoke_load_processors(
POST_LOAD,
Expand Down Expand Up @@ -1055,9 +1057,6 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None:
raise error
self.on_bind_field(field_name, field_obj)

def _has_processors(self, tag) -> bool:
return bool(self._hooks[(tag, True)] or self._hooks[(tag, False)])

def _invoke_dump_processors(
self, tag: str, data, *, many: bool, original_data=None
):
Expand Down Expand Up @@ -1102,9 +1101,8 @@ def _invoke_load_processors(
return data

def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool):
for attr_name in self._hooks[VALIDATES]:
for attr_name, _, validator_kwargs in self._hooks[VALIDATES]:
validator = getattr(self, attr_name)
validator_kwargs = validator.__marshmallow_hook__[VALIDATES]
field_name = validator_kwargs["field_name"]

try:
Expand Down Expand Up @@ -1159,11 +1157,10 @@ def _invoke_schema_validators(
partial: bool | types.StrSequenceOrSet | None,
field_errors: bool = False,
):
for attr_name in self._hooks[(VALIDATES_SCHEMA, pass_many)]:
for attr_name, hook_many, validator_kwargs in self._hooks[VALIDATES_SCHEMA]:
if hook_many != pass_many:
continue
validator = getattr(self, attr_name)
validator_kwargs = validator.__marshmallow_hook__[
(VALIDATES_SCHEMA, pass_many)
]
if field_errors and validator_kwargs["skip_on_field_errors"]:
continue
pass_original = validator_kwargs.get("pass_original", False)
Expand Down Expand Up @@ -1201,12 +1198,11 @@ def _invoke_processors(
original_data=None,
**kwargs,
):
key = (tag, pass_many)
for attr_name in self._hooks[key]:
for attr_name, hook_many, processor_kwargs in self._hooks[tag]:
if hook_many != pass_many:
continue
# This will be a bound method.
processor = getattr(self, attr_name)

processor_kwargs = processor.__marshmallow_hook__[key]
pass_original = processor_kwargs.get("pass_original", False)

if many and not pass_many:
Expand Down
1 change: 0 additions & 1 deletion src/marshmallow/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,4 @@
import typing

StrSequenceOrSet = typing.Union[typing.Sequence[str], typing.AbstractSet[str]]
Tag = typing.Union[str, typing.Tuple[str, bool]]
Validator = typing.Callable[[typing.Any], typing.Any]