diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2bf713345..285424729 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,24 @@ Changelog --------- +3.8.0 (Unreleased) +****************** + +Features: + +- The behavior of the ``unknown`` option can be further customized with a new + value, ``PROPAGATE``. If ``unknown=EXCLUDE | PROPAGATE`` is set, then the + value of ``unknown=EXCLUDE | PROPAGATE`` will be passed to any nested schemas. + This works for ``INCLUDE | PROPAGATE`` and ``RAISE | PROPAGATE`` as well. + (:issue:`1490`, :issue:`1428`) + +.. note:: + + When a schema is being loaded with ``unknown=... | PROPAGATE``, this will + override any values set for ``unknown`` in child schemas. Therefore, + ``PROPAGATE`` should only be used in cases in which you want to change + the behavior of an entire schema hierarchy. + 3.7.1 (2020-07-20) ****************** diff --git a/src/marshmallow/__init__.py b/src/marshmallow/__init__.py index 7b0b21b4d..c4617e6ef 100644 --- a/src/marshmallow/__init__.py +++ b/src/marshmallow/__init__.py @@ -9,7 +9,7 @@ validates, validates_schema, ) -from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, pprint, missing +from marshmallow.utils import EXCLUDE, INCLUDE, RAISE, PROPAGATE, pprint, missing from marshmallow.exceptions import ValidationError from distutils.version import LooseVersion @@ -19,6 +19,7 @@ "EXCLUDE", "INCLUDE", "RAISE", + "PROPAGATE", "Schema", "SchemaOpts", "fields", diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index a725c1327..e16065051 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -18,6 +18,7 @@ missing as missing_, resolve_field_instance, is_aware, + UnknownParam, ) from marshmallow.exceptions import ( ValidationError, @@ -473,7 +474,7 @@ def __init__( only: types.StrSequenceOrSet = None, exclude: types.StrSequenceOrSet = (), many: bool = False, - unknown: str = None, + unknown: typing.Union[str, UnknownParam] = None, **kwargs ): # Raise error if only or exclude 
is passed as string, not list of strings @@ -493,7 +494,7 @@ def __init__( self.only = only self.exclude = exclude self.many = many - self.unknown = unknown + self.unknown = UnknownParam.parse_if_str(unknown) self._schema = None # Cached Schema instance super().__init__(default=default, **kwargs) @@ -571,16 +572,18 @@ def _test_collection(self, value): if many and not utils.is_collection(value): raise self.make_error("type", input=value, type=value.__class__.__name__) - def _load(self, value, data, partial=None): + def _load(self, value, data, partial=None, unknown=None): try: - valid_data = self.schema.load(value, unknown=self.unknown, partial=partial) + valid_data = self.schema.load( + value, unknown=unknown or self.unknown, partial=partial, + ) except ValidationError as error: raise ValidationError( error.messages, valid_data=error.valid_data ) from error return valid_data - def _deserialize(self, value, attr, data, partial=None, **kwargs): + def _deserialize(self, value, attr, data, partial=None, unknown=None, **kwargs): """Same as :meth:`Field._deserialize` with additional ``partial`` argument. :param bool|tuple partial: For nested schemas, the ``partial`` @@ -590,7 +593,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs): Add ``partial`` parameter. 
""" self._test_collection(value) - return self._load(value, data, partial=partial) + return self._load(value, data, partial=partial, unknown=unknown) class Pluck(Nested): diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 78b386717..a2b2d0f7e 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -27,6 +27,7 @@ RAISE, EXCLUDE, INCLUDE, + UnknownParam, missing, set_value, get_value, @@ -227,7 +228,7 @@ def __init__(self, meta, ordered: bool = False): self.include = getattr(meta, "include", {}) self.load_only = getattr(meta, "load_only", ()) self.dump_only = getattr(meta, "dump_only", ()) - self.unknown = getattr(meta, "unknown", RAISE) + self.unknown = UnknownParam.parse_if_str(getattr(meta, "unknown", RAISE)) self.register = getattr(meta, "register", True) @@ -388,7 +389,7 @@ def __init__( self.load_only = set(load_only) or set(self.opts.load_only) self.dump_only = set(dump_only) or set(self.opts.dump_only) self.partial = partial - self.unknown = unknown or self.opts.unknown + self.unknown = UnknownParam.parse_if_str(unknown) or self.opts.unknown self.context = context or {} self._normalize_nested_options() #: Dictionary mapping field_names -> :class:`Field` objects @@ -609,6 +610,7 @@ def _deserialize( serializing a collection, otherwise `None`. :return: A dictionary of the deserialized data. """ + unknown = UnknownParam.parse_if_str(unknown) index_errors = self.opts.index_errors index = index if index_errors else None if many: @@ -648,7 +650,7 @@ def _deserialize( partial_is_collection and attr_name in partial ): continue - d_kwargs = {} + d_kwargs = {} # type: typing.Dict[str, typing.Any] # Allow partial loading of nested schemas. if partial_is_collection: prefix = field_name + "." 
@@ -659,6 +661,10 @@ def _deserialize( d_kwargs["partial"] = sub_partial else: d_kwargs["partial"] = partial + + if unknown.propagate: + d_kwargs["unknown"] = unknown + getter = lambda val: field_obj.deserialize( val, field_name, data, **d_kwargs ) @@ -672,16 +678,16 @@ def _deserialize( if value is not missing: key = field_obj.attribute or attr_name set_value(typing.cast(typing.Dict, ret), key, value) - if unknown != EXCLUDE: + if unknown.value != EXCLUDE.value: fields = { field_obj.data_key if field_obj.data_key is not None else field_name for field_name, field_obj in self.load_fields.items() } for key in set(data) - fields: value = data[key] - if unknown == INCLUDE: + if unknown.value == INCLUDE.value: set_value(typing.cast(typing.Dict, ret), key, value) - elif unknown == RAISE: + elif unknown.value == RAISE.value: error_store.store_error( [self.error_messages["unknown"]], key, @@ -721,7 +727,7 @@ def load( if invalid data are passed. """ return self._do_load( - data, many=many, partial=partial, unknown=unknown, postprocess=True + data, many=many, partial=partial, unknown=unknown, postprocess=True, ) def loads( @@ -836,7 +842,7 @@ def _do_load( error_store = ErrorStore() errors = {} # type: typing.Dict[str, typing.List[str]] many = self.many if many is None else bool(many) - unknown = unknown or self.unknown + unknown = UnknownParam.parse_if_str(unknown or self.unknown) if partial is None: partial = self.partial # Run preprocessors diff --git a/src/marshmallow/utils.py b/src/marshmallow/utils.py index 61ed36404..3a99a015f 100644 --- a/src/marshmallow/utils.py +++ b/src/marshmallow/utils.py @@ -1,7 +1,7 @@ """Utility methods for marshmallow.""" import collections -import functools import datetime as dt +import functools import inspect import json import re @@ -15,9 +15,55 @@ from marshmallow.exceptions import FieldInstanceResolutionError from marshmallow.warnings import RemovedInMarshmallow4Warning -EXCLUDE = "exclude" -INCLUDE = "include" -RAISE = "raise" + 
+class UnknownParam: + good_values = ("exclude", "include", "raise") + + def __init__(self, stringval=None, *, value=None, propagate=None): + self.value = value + self.propagate = propagate + + if stringval: + for x in stringval.lower().split("|"): + x = x.strip() + if x in self.good_values and not self.value: + self.value = x + if x == "propagate": + self.propagate = True + + def __or__(self, other): + return UnknownParam( + value=self.value or other.value, propagate=self.propagate or other.propagate + ) + + def __str__(self): + parts = [self.value] if self.value else [] + if self.propagate: + parts.append("propagate") + if self.value or self.propagate: + return "|".join(parts) + return "null" + + def __repr__(self): + return "UnknownParam(value={!r}, propagate={!r})".format( + self.value, self.propagate + ) + + @classmethod + def parse_if_str(cls, value): + """Given a string or UnknownParam, convert to an UnknownParam + + Preserves None, which is important for making sure that it can be used + blindly on `unknown` which may be a user-supplied value or a default""" + if isinstance(value, str): + return cls(value) + return value + + +EXCLUDE = UnknownParam("exclude") +INCLUDE = UnknownParam("include") +RAISE = UnknownParam("raise") +PROPAGATE = UnknownParam("propagate") class _Missing: @@ -41,8 +87,7 @@ def __repr__(self): def is_generator(obj) -> bool: - """Return True if ``obj`` is a generator - """ + """Return True if ``obj`` is a generator""" return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj) @@ -279,8 +324,7 @@ def set_value(dct: typing.Dict[str, typing.Any], key: str, value: typing.Any): def callable_or_raise(obj): - """Check that an object is callable, else raise a :exc:`ValueError`. 
- """ + """Check that an object is callable, else raise a :exc:`ValueError`.""" if not callable(obj): raise ValueError("Object {!r} is not callable.".format(obj)) return obj diff --git a/tests/test_fields.py b/tests/test_fields.py index 19abeb6b6..71cdca5b3 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -6,6 +6,7 @@ ValidationError, EXCLUDE, INCLUDE, + PROPAGATE, RAISE, missing, ) @@ -296,6 +297,86 @@ class MySchema(Schema): } +class TestNestedFieldPropagatesUnknown: + class SpamSchema(Schema): + meat = fields.String() + + class CanSchema(Schema): + spam = fields.Nested("SpamSchema") + + class ShelfSchema(Schema): + can = fields.Nested("CanSchema") + + @pytest.fixture + def data_nested_unknown(self): + return {"spam": {"meat": "pork", "add-on": "eggs"}} + + @pytest.fixture + def multi_nested_data_with_unknown(self, data_nested_unknown): + return {"can": data_nested_unknown, "box": {"foo": "bar"}} + + @pytest.mark.parametrize( + "schema_kwargs,load_kwargs", + [ + ({}, {"unknown": INCLUDE | PROPAGATE}), + ({}, {"unknown": "INCLUDE | PROPAGATE"}), + ({"unknown": RAISE}, {"unknown": INCLUDE | PROPAGATE}), + ({"unknown": RAISE}, {"unknown": "include|propagate"}), + ({"unknown": INCLUDE | PROPAGATE}, {}), + ], + ) + def test_propagate_unknown_include( + self, + schema_kwargs, + load_kwargs, + data_nested_unknown, + multi_nested_data_with_unknown, + ): + data = self.ShelfSchema(**schema_kwargs).load( + multi_nested_data_with_unknown, **load_kwargs + ) + assert data == { + "can": {"spam": {"meat": "pork", "add-on": "eggs"}}, + "box": {"foo": "bar"}, + } + + data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs) + assert data == {"spam": {"meat": "pork", "add-on": "eggs"}} + + @pytest.mark.parametrize( + "schema_kwargs,load_kwargs", + [ + ({}, {"unknown": EXCLUDE | PROPAGATE}), + ({}, {"unknown": "exclude | propagate"}), + ({"unknown": RAISE}, {"unknown": EXCLUDE | PROPAGATE}), + ({"unknown": PROPAGATE | EXCLUDE}, {}), + 
({"unknown": "propagate|exclude"}, {}), + ], + ) + def test_propagate_unknown_exclude( + self, + schema_kwargs, + load_kwargs, + data_nested_unknown, + multi_nested_data_with_unknown, + ): + data = self.ShelfSchema(**schema_kwargs).load( + multi_nested_data_with_unknown, **load_kwargs + ) + assert data == {"can": {"spam": {"meat": "pork"}}} + + data = self.CanSchema(**schema_kwargs).load(data_nested_unknown, **load_kwargs) + assert data == {"spam": {"meat": "pork"}} + + @pytest.mark.parametrize("schema_kw", ({}, {"unknown": INCLUDE})) + def test_raises_when_unknown_passed_to_first_level_nested( + self, schema_kw, data_nested_unknown + ): + with pytest.raises(ValidationError) as exc_info: + self.CanSchema(**schema_kw).load(data_nested_unknown) + assert exc_info.value.messages == {"spam": {"add-on": ["Unknown field."]}} + + class TestListNested: @pytest.mark.parametrize("param", ("only", "exclude", "dump_only", "load_only")) def test_list_nested_only_exclude_dump_only_load_only_propagated_to_nested( diff --git a/tests/test_schema.py b/tests/test_schema.py index 85646754b..f8df84b08 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -17,6 +17,7 @@ EXCLUDE, INCLUDE, RAISE, + PROPAGATE, class_registry, ) from marshmallow.exceptions import ( @@ -2901,3 +2902,66 @@ class DefinitelyUniqueSchema(Schema): SchemaClass = class_registry.get_class(DefinitelyUniqueSchema.__name__) assert SchemaClass is DefinitelyUniqueSchema + + +def test_propagate_unknown_overrides_explicit_value_for_nested(): + # PROPAGATE should traverse any schemas and replace them with the + # "unknown" value from the parent context (schema or load arguments) + # this test makes sure that it takes precedence when a nested field + # or schema has "unknown" set explicitly + + class Bottom(Schema): + x = fields.Str() + + class Middle(Schema): + x = fields.Str() + # set unknown explicitly on a nested field, so auto_unknown will be + # false going into Bottom + child = fields.Nested(Bottom, 
unknown=EXCLUDE) + + class Top(Schema): + x = fields.Str() + child = fields.Nested(Middle) + + data = { + "x": "hi", + "y": "bye", + "child": {"x": "hi", "y": "bye", "child": {"x": "hi", "y": "bye"}}, + } + result = Top(unknown=INCLUDE | PROPAGATE).load(data) + assert result == { + "x": "hi", + "y": "bye", + "child": {"x": "hi", "y": "bye", "child": {"x": "hi", "y": "bye"}}, + } + + +def test_propagate_unknown_overrides_explicit_value_for_meta(): + # this is the same as the above test of unknown propagation, but it checks that + # this applies when `unknown` is set by means of `Meta` as well + + class Bottom(Schema): + x = fields.Str() + + class Middle(Schema): + x = fields.Str() + child = fields.Nested(Bottom) + + class Meta: + unknown = EXCLUDE + + class Top(Schema): + x = fields.Str() + child = fields.Nested(Middle) + + data = { + "x": "hi", + "y": "bye", + "child": {"x": "hi", "y": "bye", "child": {"x": "hi", "y": "bye"}}, + } + result = Top(unknown=INCLUDE | PROPAGATE).load(data) + assert result == { + "x": "hi", + "y": "bye", + "child": {"x": "hi", "y": "bye", "child": {"x": "hi", "y": "bye"}}, + }