diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 9eaf812f5..9959aa885 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -235,7 +235,7 @@ def __init__( def __repr__(self) -> str: return ( "`. @@ -376,7 +381,6 @@ def _bind_to_schema(self, field_name, schema): :param Schema|Field schema: Parent object. """ self.parent = self.parent or schema - self.name = self.name or field_name self.root = self.root or ( self.parent.root if isinstance(self.parent, FieldABC) else self.parent ) @@ -688,7 +692,7 @@ def __init__( @property def _field_data_key(self): only_field = self.schema.fields[self.field_name] - return only_field.data_key or self.field_name + return only_field.data_key def _serialize(self, nested_obj, attr, obj, **kwargs): ret = super()._serialize(nested_obj, attr, obj, **kwargs) @@ -742,6 +746,10 @@ def __init__(self, cls_or_instance: typing.Union[Field, type], **kwargs): self.only = self.inner.only self.exclude = self.inner.exclude + def _set_name(self, field_name): + super()._set_name(field_name) + self.inner._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) self.inner = copy.deepcopy(self.inner) @@ -818,6 +826,11 @@ def __init__(self, tuple_fields, *args, **kwargs): self.validate_length = Length(equal=len(self.tuple_fields)) + def _set_name(self, field_name): + super()._set_name(field_name) + for field in self.tuple_fields: + field._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) new_tuple_fields = [] @@ -1541,6 +1554,13 @@ def __init__( self.only = self.value_field.only self.exclude = self.value_field.exclude + def _set_name(self, field_name): + super()._set_name(field_name) + if self.value_field: + self.value_field._set_name(field_name) + if self.key_field: + self.key_field._set_name(field_name) + def _bind_to_schema(self, field_name, schema): super()._bind_to_schema(field_name, schema) if self.value_field: @@ -1973,8 +1993,8 @@ class Inferred(Field): Users should not need to use this class directly. """ - def __init__(self): - super().__init__() + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) # We memoize the fields to avoid creating and binding new fields # every time on serialization. self._field_cache = {} diff --git a/src/marshmallow/schema.py b/src/marshmallow/schema.py index 12bb2f4c7..e3bb4ca6b 100644 --- a/src/marshmallow/schema.py +++ b/src/marshmallow/schema.py @@ -38,6 +38,23 @@ _T = typing.TypeVar("_T") +def _set_field_name(field_obj, field_name): + try: + field_obj._set_name(field_name) + except TypeError as error: + # Field declared as a class, not an instance. Ignore type checking because + # we handle unsupported arg types, i.e. this is dead code from + # the type checker's perspective. + if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC): + msg = ( + 'Field for "{}" must be declared as a ' + "Field instance, not a class. " + 'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__) + ) + raise TypeError(msg) from error + raise error + + def _get_fields(attrs, ordered=False): """Get fields from a class. If ordered=True, fields will sorted by creation index. @@ -51,6 +68,9 @@ def _get_fields(attrs, ordered=False): ] if ordered: fields.sort(key=lambda pair: pair[1]._creation_index) + # Set field name on each field + for field_name, field_value in fields: + _set_field_name(field_value, field_name) return fields @@ -111,6 +131,8 @@ def __new__(mcs, name, bases, attrs): # get_declared_fields klass.opts = klass.OPTIONS_CLASS(meta, ordered=ordered) # Add fields specified in the `include` class Meta option + for field_name, field_obj in klass.opts.include.items(): + _set_field_name(field_obj, field_name) cls_fields += list(klass.opts.include.items()) dict_cls = OrderedDict if ordered else dict @@ -520,8 +542,7 @@ def _serialize( value = field_obj.serialize(attr_name, obj, accessor=self.get_attribute) if value is missing: continue - key = field_obj.data_key if field_obj.data_key is not None else attr_name - ret[key] = value + ret[field_obj.data_key] = value return ret def dump(self, obj: typing.Any, *, many: typing.Optional[bool] = None): @@ -634,10 +655,8 @@ def _deserialize( else: partial_is_collection = is_collection(partial) for attr_name, field_obj in self.load_fields.items(): - field_name = ( - field_obj.data_key if field_obj.data_key is not None else attr_name - ) - raw_value = data.get(field_name, missing) + data_key = typing.cast(str, field_obj.data_key) + raw_value = data.get(data_key, missing) if raw_value is missing: # Ignore missing field if we're allowed to. if partial is True or ( @@ -647,7 +666,7 @@ def _deserialize( d_kwargs = {} # Allow partial loading of nested schemas. if partial_is_collection: - prefix = field_name + "." + prefix = data_key + "." len_prefix = len(prefix) sub_partial = [ f[len_prefix:] for f in partial if f.startswith(prefix) @@ -656,23 +675,20 @@ def _deserialize( else: d_kwargs["partial"] = partial getter = lambda val: field_obj.deserialize( - val, field_name, data, **d_kwargs + val, data_key, data, **d_kwargs ) value = self._call_and_store( getter_func=getter, data=raw_value, - field_name=field_name, + field_name=data_key, error_store=error_store, index=index, ) if value is not missing: - key = field_obj.attribute or attr_name - set_value(ret_d, key, value) + attribute = typing.cast(str, field_obj.attribute) + set_value(ret_d, attribute, value) if unknown != EXCLUDE: - fields = { - field_obj.data_key if field_obj.data_key is not None else field_name - for field_name, field_obj in self.load_fields.items() - } + fields = {field_obj.data_key for field_obj in self.load_fields.values()} for key in set(data) - fields: value = data[key] if unknown == INCLUDE: @@ -975,7 +991,10 @@ def _init_fields(self) -> None: fields_dict = self.dict_class() for field_name in field_names: - field_obj = self.declared_fields.get(field_name, ma_fields.Inferred()) + field_obj = self.declared_fields.get( + field_name, + ma_fields.Inferred(attribute=field_name, data_key=field_name), + ) self._bind_field(field_name, field_obj) fields_dict[field_name] = field_obj @@ -986,10 +1005,7 @@ def _init_fields(self) -> None: if not field_obj.load_only: dump_fields[field_name] = field_obj - dump_data_keys = [ - field_obj.data_key if field_obj.data_key is not None else name - for name, field_obj in dump_fields.items() - ] + dump_data_keys = [field_obj.data_key for field_obj in dump_fields.values()] if len(dump_data_keys) != len(set(dump_data_keys)): data_keys_duplicates = { x for x in dump_data_keys if dump_data_keys.count(x) > 1 @@ -1000,7 +1016,7 @@ def _init_fields(self) -> None: "Check the following field names and " "data_key arguments: {}".format(list(data_keys_duplicates)) ) - load_attributes = [obj.attribute or name for name, obj in load_fields.items()] + load_attributes = [obj.attribute for obj in load_fields.values()] if len(load_attributes) != len(set(load_attributes)): attributes_duplicates = { x for x in load_attributes if load_attributes.count(x) > 1 @@ -1034,20 +1050,7 @@ def _bind_field(self, field_name: str, field_obj: ma_fields.Field) -> None: field_obj.load_only = True if field_name in self.dump_only: field_obj.dump_only = True - try: - field_obj._bind_to_schema(field_name, self) - except TypeError as error: - # Field declared as a class, not an instance. Ignore type checking because - # we handle unsupported arg types, i.e. this is dead code from - # the type checker's perspective. - if isinstance(field_obj, type) and issubclass(field_obj, base.FieldABC): - msg = ( - 'Field for "{}" must be declared as a ' - "Field instance, not a class. " - 'Did you mean "fields.{}()"?'.format(field_name, field_obj.__name__) - ) - raise TypeError(msg) from error - raise error + field_obj._bind_to_schema(field_name, self) self.on_bind_field(field_name, field_obj) @lru_cache(maxsize=8) @@ -1110,13 +1113,11 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) continue raise ValueError(f'"{field_name}" field does not exist.') from error - data_key = ( - field_obj.data_key if field_obj.data_key is not None else field_name - ) + data_key = field_obj.data_key if many: for idx, item in enumerate(data): try: - value = item[field_obj.attribute or field_name] + value = item[field_obj.attribute] except KeyError: pass else: @@ -1131,7 +1132,7 @@ def _invoke_field_validators(self, *, error_store: ErrorStore, data, many: bool) data[idx].pop(field_name, None) else: try: - value = data[field_obj.attribute or field_name] + value = data[field_obj.attribute] except KeyError: pass else: diff --git a/tests/test_fields.py b/tests/test_fields.py index e7f552bd1..9112c1e1a 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -32,7 +32,7 @@ def test_repr(self): default = "œ∑´" field = fields.Field(dump_default=default, attribute=None) assert repr(field) == ( - "