Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate unknown when an explicit flag is set, in the form unknown=...|PROPAGATE #1634

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,24 @@
Changelog
---------

3.8.0 (Unreleased)
******************

Features:

- The behavior of the ``unknown`` option can be further customized with a new
value, ``PROPAGATE``. If ``unknown=EXCLUDE | PROPAGATE`` is set, then the
value of ``unknown=EXCLUDE | PROPAGATE`` will be passed to any nested schemas.
This works for ``INCLUDE | PROPAGATE`` and ``RAISE | PROPAGATE`` as well.
(:issue:`1490`, :issue:`1428`)

.. note::

When a schema is being loaded with ``unknown=... | PROPAGATE``, this will
override any values set for ``unknown`` in child schemas. Therefore,
``PROPAGATE`` should only be used in cases in which you want to change
the behavior of an entire schema heirarchy.

3.7.1 (2020-07-20)
******************

Expand Down
3 changes: 2 additions & 1 deletion src/marshmallow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
validates,
validates_schema,
)
from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, pprint, missing
from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, PROPAGATE, pprint, missing
from marshmallow.exceptions import ValidationError
from distutils.version import LooseVersion

Expand All @@ -19,6 +19,7 @@
"EXCLUDE",
"INCLUDE",
"RAISE",
"PROPAGATE",
"Schema",
"SchemaOpts",
"fields",
Expand Down
15 changes: 9 additions & 6 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
missing as missing_,
resolve_field_instance,
is_aware,
UnknownParam,
)
from marshmallow.exceptions import (
ValidationError,
Expand Down Expand Up @@ -473,7 +474,7 @@ def __init__(
only: types.StrSequenceOrSet = None,
exclude: types.StrSequenceOrSet = (),
many: bool = False,
unknown: str = None,
unknown: typing.Union[str, UnknownParam] = None,
**kwargs
):
# Raise error if only or exclude is passed as string, not list of strings
Expand All @@ -493,7 +494,7 @@ def __init__(
self.only = only
self.exclude = exclude
self.many = many
self.unknown = unknown
self.unknown = UnknownParam.parse_if_str(unknown)
self._schema = None # Cached Schema instance
super().__init__(default=default, **kwargs)

Expand Down Expand Up @@ -571,16 +572,18 @@ def _test_collection(self, value):
if many and not utils.is_collection(value):
raise self.make_error("type", input=value, type=value.__class__.__name__)

def _load(self, value, data, partial=None):
def _load(self, value, data, partial=None, unknown=None):
try:
valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
valid_data = self.schema.load(
value, unknown=unknown or self.unknown, partial=partial,
)
except ValidationError as error:
raise ValidationError(
error.messages, valid_data=error.valid_data
) from error
return valid_data

def _deserialize(self, value, attr, data, partial=None, **kwargs):
def _deserialize(self, value, attr, data, partial=None, unknown=None, **kwargs):
"""Same as :meth:`Field._deserialize` with additional ``partial`` argument.

:param bool|tuple partial: For nested schemas, the ``partial``
Expand All @@ -590,7 +593,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
Add ``partial`` parameter.
"""
self._test_collection(value)
return self._load(value, data, partial=partial)
return self._load(value, data, partial=partial, unknown=unknown)


class Pluck(Nested):
Expand Down
22 changes: 14 additions & 8 deletions src/marshmallow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
RAISE,
EXCLUDE,
INCLUDE,
UnknownParam,
missing,
set_value,
get_value,
Expand Down Expand Up @@ -227,7 +228,7 @@ def __init__(self, meta, ordered: bool = False):
self.include = getattr(meta, "include", {})
self.load_only = getattr(meta, "load_only", ())
self.dump_only = getattr(meta, "dump_only", ())
self.unknown = getattr(meta, "unknown", RAISE)
self.unknown = UnknownParam.parse_if_str(getattr(meta, "unknown", RAISE))
self.register = getattr(meta, "register", True)


Expand Down Expand Up @@ -388,7 +389,7 @@ def __init__(
self.load_only = set(load_only) or set(self.opts.load_only)
self.dump_only = set(dump_only) or set(self.opts.dump_only)
self.partial = partial
self.unknown = unknown or self.opts.unknown
self.unknown = UnknownParam.parse_if_str(unknown) or self.opts.unknown
self.context = context or {}
self._normalize_nested_options()
#: Dictionary mapping field_names -> :class:`Field` objects
Expand Down Expand Up @@ -609,6 +610,7 @@ def _deserialize(
serializing a collection, otherwise `None`.
:return: A dictionary of the deserialized data.
"""
unknown = UnknownParam.parse_if_str(unknown)
index_errors = self.opts.index_errors
index = index if index_errors else None
if many:
Expand Down Expand Up @@ -648,7 +650,7 @@ def _deserialize(
partial_is_collection and attr_name in partial
):
continue
d_kwargs = {}
d_kwargs = {} # type: typing.Dict[str, typing.Any]
# Allow partial loading of nested schemas.
if partial_is_collection:
prefix = field_name + "."
Expand All @@ -659,6 +661,10 @@ def _deserialize(
d_kwargs["partial"] = sub_partial
else:
d_kwargs["partial"] = partial

if unknown.propagate:
d_kwargs["unknown"] = unknown

getter = lambda val: field_obj.deserialize(
val, field_name, data, **d_kwargs
)
Expand All @@ -672,16 +678,16 @@ def _deserialize(
if value is not missing:
key = field_obj.attribute or attr_name
set_value(typing.cast(typing.Dict, ret), key, value)
if unknown != EXCLUDE:
if unknown.value != EXCLUDE.value:
fields = {
field_obj.data_key if field_obj.data_key is not None else field_name
for field_name, field_obj in self.load_fields.items()
}
for key in set(data) - fields:
value = data[key]
if unknown == INCLUDE:
if unknown.value == INCLUDE.value:
set_value(typing.cast(typing.Dict, ret), key, value)
elif unknown == RAISE:
elif unknown.value == RAISE.value:
error_store.store_error(
[self.error_messages["unknown"]],
key,
Expand Down Expand Up @@ -721,7 +727,7 @@ def load(
if invalid data are passed.
"""
return self._do_load(
data, many=many, partial=partial, unknown=unknown, postprocess=True
data, many=many, partial=partial, unknown=unknown, postprocess=True,
)

def loads(
Expand Down Expand Up @@ -836,7 +842,7 @@ def _do_load(
error_store = ErrorStore()
errors = {} # type: typing.Dict[str, typing.List[str]]
many = self.many if many is None else bool(many)
unknown = unknown or self.unknown
unknown = UnknownParam.parse_if_str(unknown or self.unknown)
if partial is None:
partial = self.partial
# Run preprocessors
Expand Down
60 changes: 52 additions & 8 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Utility methods for marshmallow."""
import collections
import functools
import datetime as dt
import functools
import inspect
import json
import re
Expand All @@ -15,9 +15,55 @@
from marshmallow.exceptions import FieldInstanceResolutionError
from marshmallow.warnings import RemovedInMarshmallow4Warning

EXCLUDE = "exclude"
INCLUDE = "include"
RAISE = "raise"

class UnknownParam:
good_values = ("exclude", "include", "raise")

def __init__(self, stringval=None, *, value=None, propagate=None):
self.value = value
self.propagate = propagate

if stringval:
for x in stringval.lower().split("|"):
x = x.strip()
if x in self.good_values and not self.value:
self.value = x
if x == "propagate":
self.propagate = True

def __or__(self, other):
return UnknownParam(
value=self.value or other.value, propagate=self.propagate or other.propagate
)

def __str__(self):
parts = [self.value] if self.value else []
if self.propagate:
parts.append("propagate")
if self.value or self.propagate:
return "|".join(parts)
return "null"

def __repr__(self):
return "UnknownParam(value={!r}, propagate={!r})".format(
self.value, self.propagate
)

@classmethod
def parse_if_str(cls, value):
"""Given a string or UnknownParam, convert to an UnknownParam

Preserves None, which is important for making sure that it can be used
blindly on `unknown` which may be a user-supplied value or a default"""
if isinstance(value, str):
return cls(value)
return value


EXCLUDE = UnknownParam("exclude")
INCLUDE = UnknownParam("include")
RAISE = UnknownParam("raise")
PROPAGATE = UnknownParam("propagate")


class _Missing:
Expand All @@ -41,8 +87,7 @@ def __repr__(self):


def is_generator(obj) -> bool:
"""Return True if ``obj`` is a generator
"""
"""Return True if ``obj`` is a generator"""
return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)


Expand Down Expand Up @@ -279,8 +324,7 @@ def set_value(dct: typing.Dict[str, typing.Any], key: str, value: typing.Any):


def callable_or_raise(obj):
"""Check that an object is callable, else raise a :exc:`ValueError`.
"""
"""Check that an object is callable, else raise a :exc:`ValueError`."""
if not callable(obj):
raise ValueError("Object {!r} is not callable.".format(obj))
return obj
Expand Down
81 changes: 81 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
ValidationError,
EXCLUDE,
INCLUDE,
PROPAGATE,
RAISE,
missing,
)
Expand Down Expand Up @@ -296,6 +297,86 @@ class MySchema(Schema):
}


class TestNestedFieldPropagatesUnknown:
class SpamSchema(Schema):
meat = fields.String()

class CanSchema(Schema):
spam = fields.Nested("SpamSchema")

class ShelfSchema(Schema):
can = fields.Nested("CanSchema")

@pytest.fixture
def data_nested_unknown(self):
return {"spam": {"meat": "pork", "add-on": "eggs"}}

@pytest.fixture
def multi_nested_data_with_unknown(self, data_nested_unknown):
return {"can": data_nested_unknown, "box": {"foo": "bar"}}

@pytest.mark.parametrize(
"schema_kwargs,load_kwargs",
[
({}, {"unknown": INCLUDE | PROPAGATE}),
({}, {"unknown": "INCLUDE | PROPAGATE"}),
({"unknown": RAISE}, {"unknown": INCLUDE | PROPAGATE}),
({"unknown": RAISE}, {"unknown": "include|propagate"}),
({"unknown": INCLUDE | PROPAGATE}, {}),
],
)
def test_propagate_unknown_include(
self,
schema_kwargs,
load_kwargs,
data_nested_unknown,
multi_nested_data_with_unknown,
):
data = self.ShelfSchema(**schema_kwargs).load(
multi_nested_data_with_unknown, **load_kwargs
)
assert data == {
"can": {"spam": {"meat": "pork", "add-on": "eggs"}},
"box": {"foo": "bar"},
}

data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs)
assert data == {"spam": {"meat": "pork", "add-on": "eggs"}}

@pytest.mark.parametrize(
"schema_kwargs,load_kwargs",
[
({}, {"unknown": EXCLUDE | PROPAGATE}),
({}, {"unknown": "exclude | propagate"}),
({"unknown": RAISE}, {"unknown": EXCLUDE | PROPAGATE}),
({"unknown": PROPAGATE | EXCLUDE}, {}),
({"unknown": "propagate|exclude"}, {}),
],
)
def test_propagate_unknown_exclude(
self,
schema_kwargs,
load_kwargs,
data_nested_unknown,
multi_nested_data_with_unknown,
):
data = self.ShelfSchema(**schema_kwargs).load(
multi_nested_data_with_unknown, **load_kwargs
)
assert data == {"can": {"spam": {"meat": "pork"}}}

data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs)
assert data == {"spam": {"meat": "pork"}}

@pytest.mark.parametrize("schema_kw", ({}, {"unknown": INCLUDE}))
def test_raises_when_unknown_passed_to_first_level_nested(
self, schema_kw, data_nested_unknown
):
with pytest.raises(ValidationError) as exc_info:
self.CanSchema(**schema_kw).load(data_nested_unknown)
assert exc_info.value.messages == {"spam": {"add-on": ["Unknown field."]}}


class TestListNested:
@pytest.mark.parametrize("param", ("only", "exclude", "dump_only", "load_only"))
def test_list_nested_only_exclude_dump_only_load_only_propagated_to_nested(
Expand Down
Loading