From 747f3644379a5a58c361fb6ab59ebea2e28946d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Wed, 27 Oct 2021 17:00:40 +0200 Subject: [PATCH 1/7] Let data_key default to field name --- src/marshmallow/fields.py | 3 ++- src/marshmallow/schema.py | 21 +++++---------------- tests/test_fields.py | 9 +++++++++ 3 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 9eaf812f5..f1a1b4ef5 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -377,6 +377,7 @@ def _bind_to_schema(self, field_name, schema): """ self.parent = self.parent or schema self.name = self.name or field_name + self.data_key = self.data_key if self.data_key is not None else field_name self.root = self.root or ( self.parent.root if isinstance(self.parent, FieldABC) else self.parent ) @@ -688,7 +689,7 @@ def __init__( @property def _field_data_key(self): only_field = self.schema.fields[self.field_name] - return only_field.data_key or self.field_name + return only_field.data_key def _serialize(self, nested_obj, attr, obj, **kwargs): ret = super()._serialize(nested_obj, attr, obj, **kwargs) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 12bb2f4c7..5a95b987e 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -520,8 +520,7 @@ def _serialize( value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute) if value is missing: continue - key = field_obj.data_key if field_obj.data_key is not None else attr_name - ret[key] = value + ret[field_obj.data_key] = value return ret def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None): @@ -634,9 +633,7 @@ def _deserialize( else: partial_is_collection = is_collection(partial) for attr_name, field_obj in self.load_fields.items(): - field_name = ( - field_obj.data_key if field_obj.data_key is not None else attr_name - ) + field_name = field_obj.data_key raw_value = data.get(field_name, missing) if raw_value is missing: # Ignore missing field if we're allowed to. @@ -669,10 +666,7 @@ def _deserialize( key = field_obj.attribute or attr_name set_value(ret_d, key, value) if unknown != EXCLUDE: - fields = { - field_obj.data_key if field_obj.data_key is not None else field_name - for field_name, field_obj in self.load_fields.items() - } + fields = {field_obj.data_key for field_obj in self.load_fields.values()} for key in set(data) - fields: value = data[key] if unknown == INCLUDE: @@ -986,10 +980,7 @@ def _init_fields(self) -> None: if not field_obj.load_only: dump_fields[field_name] = field_obj - dump_data_keys = [ - field_obj.data_key if field_obj.data_key is not None else name - for name, field_obj in dump_fields.items() - ] + dump_data_keys = [field_obj.data_key for field_obj in dump_fields.values()] if len(dump_data_keys) != len(set(dump_data_keys)): data_keys_duplicates = { x for x in dump_data_keys if dump_data_keys.count(x) > 1 @@ -1110,9 +1101,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) continue raise ValueError(f'"{field_name}" field does not exist.') from error - data_key = ( - field_obj.data_key if field_obj.data_key is not None else field_name - ) + data_key = field_obj.data_key if many: for idx, item in enumerate(data): try: diff --git a/tests/test_fields.py b/tests/test_fields.py index e7f552bd1..443dc8d6c 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -92,6 +92,15 @@ class MySchema(Schema): result = MySchema().dump({"name": "Monty", "foo": 42}) assert result == {"_NaMe": "Monty"} + def test_data_key_defaults_to_field_name(self): + class MySchema(Schema): + field_1 = fields.String(data_key="field_one") + field_2 = fields.String() + + schema_fields = MySchema().fields + assert schema_fields["field_1"].data_key == "field_one" + assert schema_fields["field_2"].data_key == "field_2" + class TestParentAndName: class MySchema(Schema): From f61370c7e2f04256d1f3ee4967d79497dcd58823 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 28 Oct 2021 00:25:07 +0200 Subject: [PATCH 2/7] typing.cast field_obj.data_key --- src/marshmallow/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 5a95b987e..f474b2122 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -633,7 +633,7 @@ def _deserialize( else: partial_is_collection = is_collection(partial) for attr_name, field_obj in self.load_fields.items(): - field_name = field_obj.data_key + field_name = typing.cast(str, field_obj.data_key) raw_value = data.get(field_name, missing) if raw_value is missing: # Ignore missing field if we're allowed to. From 43b0bc796e48f25d030b26bafd2892189e6c5342 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 28 Oct 2021 08:40:34 +0200 Subject: [PATCH 3/7] Add data_key to Field.__repr__ --- src/marshmallow/fields.py | 2 +- tests/test_fields.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index f1a1b4ef5..75936c50e 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -235,7 +235,7 @@ def __init__( def __repr__(self) -> str: return ( " Date: Thu, 28 Oct 2021 08:48:45 +0200 Subject: [PATCH 4/7] Let attribute default to field name --- src/marshmallow/fields.py | 1 + src/marshmallow/schema.py | 8 ++++---- tests/test_fields.py | 9 +++++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 75936c50e..93e148ffd 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -378,6 +378,7 @@ def _bind_to_schema(self, field_name, schema): self.parent = self.parent or schema self.name = self.name or field_name self.data_key = self.data_key if self.data_key is not None else field_name + self.attribute = self.attribute if self.attribute is not None else field_name self.root = self.root or ( self.parent.root if isinstance(self.parent, FieldABC) else self.parent ) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index f474b2122..a8dc6ce04 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -663,7 +663,7 @@ def _deserialize( index=index, ) if value is not missing: - key = field_obj.attribute or attr_name + key = field_obj.attribute set_value(ret_d, key, value) if unknown != EXCLUDE: fields = {field_obj.data_key for field_obj in self.load_fields.values()} @@ -991,7 +991,7 @@ def _init_fields(self) -> None: "Check the following field names and " "data_key arguments: {}".format(list(data_keys_duplicates)) ) - load_attributes = [obj.attribute or name for name, obj in load_fields.items()] + load_attributes = [obj.attribute for obj in load_fields.values()] if len(load_attributes) != len(set(load_attributes)): attributes_duplicates = { x for x in load_attributes if load_attributes.count(x) > 1 @@ -1105,7 +1105,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) if many: for idx, item in enumerate(data): try: - value = item[field_obj.attribute or field_name] + value = item[field_obj.attribute] except KeyError: pass else: @@ -1120,7 +1120,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) data[idx].pop(field_name, None) else: try: - value = data[field_obj.attribute or field_name] + value = data[field_obj.attribute] except KeyError: pass else: diff --git a/tests/test_fields.py b/tests/test_fields.py index 4dd30006b..9112c1e1a 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -101,6 +101,15 @@ class MySchema(Schema): assert schema_fields["field_1"].data_key == "field_one" assert schema_fields["field_2"].data_key == "field_2" + def test_attribute_defaults_to_field_name(self): + class MySchema(Schema): + field_1 = fields.String(attribute="field_one") + field_2 = fields.String() + + schema_fields = MySchema().fields + assert schema_fields["field_1"].attribute == "field_one" + assert schema_fields["field_2"].attribute == "field_2" + class TestParentAndName: class MySchema(Schema): From a53475a3fe4cd089606f8e0d7c06b21fd9a77cc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 28 Oct 2021 08:49:30 +0200 Subject: [PATCH 5/7] typing.cast field_obj.attribute --- src/marshmallow/schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index a8dc6ce04..13653baf7 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -663,7 +663,7 @@ def _deserialize( index=index, ) if value is not missing: - key = field_obj.attribute + key = typing.cast(str, field_obj.attribute) set_value(ret_d, key, value) if unknown != EXCLUDE: fields = {field_obj.data_key for field_obj in self.load_fields.values()} From 0d12e42d631b9f2688f822fb18513fe8bee4a8e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 28 Oct 2021 09:00:26 +0200 Subject: [PATCH 6/7] Rename local variables in Schema methods --- src/marshmallow/schema.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 13653baf7..d223ed073 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -633,8 +633,8 @@ def _deserialize( else: partial_is_collection = is_collection(partial) for attr_name, field_obj in self.load_fields.items(): - field_name = typing.cast(str, field_obj.data_key) - raw_value = data.get(field_name, missing) + data_key = typing.cast(str, field_obj.data_key) + raw_value = data.get(data_key, missing) if raw_value is missing: # Ignore missing field if we're allowed to. if partial is True or ( @@ -644,7 +644,7 @@ def _deserialize( d_kwargs = {} # Allow partial loading of nested schemas. if partial_is_collection: - prefix = field_name + "." + prefix = data_key + "." len_prefix = len(prefix) sub_partial = [ f[len_prefix:] for f in partial if f.startswith(prefix) @@ -653,18 +653,18 @@ def _deserialize( else: d_kwargs["partial"] = partial getter = lambda val: field_obj.deserialize( - val, field_name, data, **d_kwargs + val, data_key, data, **d_kwargs ) value = self._call_and_store( getter_func=getter, data=raw_value, - field_name=field_name, + field_name=data_key, error_store=error_store, index=index, ) if value is not missing: - key = typing.cast(str, field_obj.attribute) - set_value(ret_d, key, value) + attribute = typing.cast(str, field_obj.attribute) + set_value(ret_d, attribute, value) if unknown != EXCLUDE: fields = {field_obj.data_key for field_obj in self.load_fields.values()} for key in set(data) - fields: From a300e1ca72c0bdf7235e23a0ca3edae98aa1d9cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Thu, 28 Oct 2021 10:34:16 +0200 Subject: [PATCH 7/7] Set field names on Schema creation --- src/marshmallow/fields.py | 28 +++++++++++++++++++++----- src/marshmallow/schema.py | 42 +++++++++++++++++++++++++-------------- tests/test_schema.py | 7 +++---- 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 93e148ffd..9959aa885 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -368,6 +368,11 @@ def deserialize( # Methods for concrete classes to override. + def _set_name(self, field_name): + self.name = self.name or field_name + self.data_key = self.data_key if self.data_key is not None else field_name + self.attribute = self.attribute if self.attribute is not None else field_name + def _bind_to_schema(self, field_name, schema): """Update field with values from its parent schema. Called by :meth:`Schema._bind_field `. @@ -376,9 +381,6 @@ def _bind_to_schema(self, field_name, schema): :param Schema|Field schema: Parent object. """ self.parent = self.parent or schema - self.name = self.name or field_name - self.data_key = self.data_key if self.data_key is not None else field_name - self.attribute = self.attribute if self.attribute is not None else field_name self.root = self.root or ( self.parent.root if isinstance(self.parent, FieldABC) else self.parent ) @@ -744,6 +746,10 @@ def __init__(self, cls_or_instance: typing.Union[Field, type], **kwargs): self.only = self.inner.only self.exclude = self.inner.exclude + def _set_name(self, field_name): + super()._set_name(field_name) + self.inner._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) self.inner = copy.deepcopy(self.inner) @@ -820,6 +826,11 @@ def __init__(self, tuple_fields, *args, **kwargs): self.validate_length = Length(equal=len(self.tuple_fields)) + def _set_name(self, field_name): + super()._set_name(field_name) + for field in self.tuple_fields: + field._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) new_tuple_fields = [] @@ -1543,6 +1554,13 @@ def __init__( self.only = self.value_field.only self.exclude = self.value_field.exclude + def _set_name(self, field_name): + super()._set_name(field_name) + if self.value_field: + self.value_field._set_name(field_name) + if self.key_field: + self.key_field._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) if self.value_field: @@ -1975,8 +1993,8 @@ class Inferred(Field): Users should not need to use this class directly. """ - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) # We memoize the fields to avoid creating and binding new fields # every time on serialization. self._field_cache = {} diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index d223ed073..e3bb4ca6b 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -38,6 +38,23 @@ _T = typing.TypeVar("_T") +def _set_field_name(field_obj, field_name): + try: + field_obj._set_name(field_name) + except TypeError as error: + # Field declared as a class, not an instance. Ignore type checking because + # we handle unsupported arg types, i.e. this is dead code from + # the type checker's perspective. + if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC): + msg = ( + 'Field for "{}" must be declared as a ' + "Field instance, not a class. " + 'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__) + ) + raise TypeError(msg) from error + raise error + + def _get_fields(attrs, ordered=False): """Get fields from a class. If ordered=True, fields will sorted by creation index. @@ -51,6 +68,9 @@ def _get_fields(attrs, ordered=False): ] if ordered: fields.sort(key=lambda pair: pair[1]._creation_index) + # Set field name on each field + for field_name, field_value in fields: + _set_field_name(field_value, field_name) return fields @@ -111,6 +131,8 @@ def __new__(mcs, name, bases, attrs): # get_declared_fields klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered) # Add fields specified in the `include` class Meta option + for field_name, field_obj in klass.opts.include.items(): + _set_field_name(field_obj, field_name) cls_fields += list(klass.opts.include.items()) dict_cls = OrderedDict if ordered else dict @@ -969,7 +991,10 @@ def _init_fields(self) -> None: fields_dict = self.dict_class() for field_name in field_names: - field_obj = self.declared_fields.get(field_name, ma_fields.Inferred()) + field_obj = self.declared_fields.get( + field_name, + ma_fields.Inferred(attribute=field_name, data_key=field_name), + ) self._bind_field(field_name, field_obj) fields_dict[field_name] = field_obj @@ -1025,20 +1050,7 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None: field_obj.load_only = True if field_name in self.dump_only: field_obj.dump_only = True - try: - field_obj._bind_to_schema(field_name, self) - except TypeError as error: - # Field declared as a class, not an instance. Ignore type checking because - # we handle unsupported arg types, i.e. this is dead code from - # the type checker's perspective. - if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC): - msg = ( - 'Field for "{}" must be declared as a ' - "Field instance, not a class. " - 'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__) - ) - raise TypeError(msg) from error - raise error + field_obj._bind_to_schema(field_name, self) self.on_bind_field(field_name, field_obj) @lru_cache(maxsize=8) diff --git a/tests/test_schema.py b/tests/test_schema.py index a289dbdf8..f3684eb52 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -553,11 +553,10 @@ def test_function_field(serialized_user, user): def test_fields_must_be_declared_as_instances(user): - class BadUserSchema(Schema): - name = fields.String - with pytest.raises(TypeError, match="must be declared as a Field instance"): - BadUserSchema().dump(user) + + class BadUserSchema(Schema): + name = fields.String # regression test