Skip to content

Commit

Permalink
Fix generic dataclasses with bound parameters.
Browse files Browse the repository at this point in the history
This alters the way in which mappable dataclasses are created in order to fix a
crash when a mappable, frozen generic dataclass is instantiated with a bound
type parameter.

PiperOrigin-RevId: 521409933
  • Loading branch information
stompchicken authored and ChexDev committed Apr 24, 2023
1 parent fcf33ee commit 0f4f0a6
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 43 deletions.
121 changes: 83 additions & 38 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
import dataclasses
import functools
import inspect

from absl import logging
import jax
Expand All @@ -27,32 +28,55 @@
_RESERVED_DCLS_FIELD_NAMES = frozenset(("from_tuple", "replace", "to_tuple"))


def mappable_dataclass(cls):
"""Exposes dataclass as ``collections.abc.Mapping`` descendent.
def _make_mappable(cls):
"""Create type that implements and inherits from ``collections.abc.Mapping``.
Allows to traverse dataclasses in methods from `dm-tree` library.
Note that this does not require the class to be a dataclass, as it is supposed
to be applied before creating the dataclass.
NOTE: changes dataclasses constructor to dict-type
(i.e. positional args aren't supported; however can use generators/iterables).
Allows to traverse dataclasses in methods from `dm-tree` library.
Args:
cls: A dataclass to mutate.
cls: A class to use as a base for the new type.
Returns:
Mutated dataclass implementing ``collections.abc.Mapping`` interface.
type implementing and inheriting from ``collections.abc.Mapping``.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

# Define methods for compatibility with `collections.abc.Mapping`.
setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
setattr(cls, "__len__", lambda self: len(self.__dict__))
setattr(cls, "__iter__", lambda self: iter(self.__dict__))

# Update constructor.
orig_init = cls.__init__
all_fields = set(f.name for f in cls.__dataclass_fields__.values())
init_fields = [f.name for f in cls.__dataclass_fields__.values() if f.init]
# Update base class to derive from Mapping
dct = dict(cls.__dict__)
if "__dict__" in dct:
dct.pop("__dict__") # Avoid self-references.

# Remove object from the sequence of base classes. Deriving from both Mapping
# and object will cause a failure to create a MRO for the updated class
bases = tuple(b for b in cls.__bases__ if b != object)
return type(cls.__name__, bases + (collections.abc.Mapping,), dct)


def _convert_kw_only_dataclass_init(dcls):
"""Create wrapped initializer that converts everything to keyword arguments.
This should be equivalent to passing `kw_only=True` when creating the
dataclass in Python >= 3.10.
Args:
dcls: the dataclass to take the constructor from.
Returns:
Initializer wrapping the original initializer but which requires
keyword-only arguments.
Raises:
ValueError: if all required arguments are not provided as keyword-only.
"""
orig_init = dcls.__init__
all_fields = set(f.name for f in dcls.__dataclass_fields__.values())
init_fields = [f.name for f in dcls.__dataclass_fields__.values() if f.init]

@functools.wraps(orig_init)
def new_init(self, *orig_args, **orig_kwargs):
Expand All @@ -69,17 +93,28 @@ def new_init(self, *orig_args, **orig_kwargs):
valid_kwargs = {k: v for k, v in all_kwargs.items() if k in init_fields}
orig_init(self, **valid_kwargs)

cls.__init__ = new_init
return new_init

# Update base class to derive from Mapping
dct = dict(cls.__dict__)
if "__dict__" in dct:
dct.pop("__dict__") # Avoid self-references.

# Remove object from the sequence of base classes. Deriving from both Mapping
# and object will cause a failure to create a MRO for the updated class
bases = tuple(b for b in cls.__bases__ if b != object)
cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct)
def mappable_dataclass(cls):
  """Turns a dataclass into a ``collections.abc.Mapping`` descendant.

  Makes dataclasses traversable by methods from the `dm-tree` library.

  NOTE: the dataclass constructor becomes dict-like, i.e. positional
  arguments are not supported (generators/iterables can be used instead).

  Args:
    cls: A dataclass to mutate.

  Returns:
    Dataclass implementing the ``collections.abc.Mapping`` interface.

  Raises:
    ValueError: if `cls` is not a dataclass.
  """
  if dataclasses.is_dataclass(cls):
    mappable_cls = _make_mappable(cls)
    # The initializer is built from the mappable type, then installed on it.
    mappable_cls.__init__ = _convert_kw_only_dataclass_init(mappable_cls)
    return mappable_cls

  raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")


Expand Down Expand Up @@ -159,37 +194,40 @@ def __init__(
def __call__(self, cls):
"""Forwards class to dataclasses's wrapper and registers it with JAX."""

if self.mappable_dataclass:
cls = _make_mappable(cls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(cls, attr, None) # redefine to avoid AttributeError on delattr
delattr(cls, attr) # delete

# Remove once https://github.com/python/cpython/pull/24484 is merged.
for base in cls.__bases__:
if (dataclasses.is_dataclass(base) and
getattr(base, "__dataclass_params__").frozen and not self.frozen):
raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

# Check for invalid field names.
annotations = inspect.get_annotations(cls)
fields_names = set(name for name in annotations.keys())
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({cls}).")

# pytype: disable=wrong-keyword-args
dcls = dataclasses.dataclass(
cls,
init=self.init,
repr=self.repr,
eq=self.eq,
order=self.order,
# kw_only=self.mappable_dataclass,
unsafe_hash=self.unsafe_hash,
frozen=self.frozen)
# pytype: enable=wrong-keyword-args

fields_names = set(f.name for f in dataclasses.fields(dcls))
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
f"{invalid_fields} ({dcls}).")

if self.mappable_dataclass:
dcls = mappable_dataclass(dcls)
# We remove `collections.abc.Mapping` mixin methods here to allow
# fields with these names.
for attr in ("values", "keys", "get", "items"):
setattr(dcls, attr, None) # redefine
delattr(dcls, attr) # delete

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))

Expand All @@ -212,6 +250,9 @@ def _setstate(self, state):
self.__dict__.update(state)

orig_init = dcls.__init__
is_mappable_dataclass = self.mappable_dataclass
if self.mappable_dataclass:
kw_only_init = _convert_kw_only_dataclass_init(dcls)

# Patch object's __init__ such that the class is registered on creation if
# it is not registered on deserialization.
Expand All @@ -220,7 +261,11 @@ def _init(self, *args, **kwargs):
if not class_self.registered:
register_dataclass_type_with_jax_tree_util(dcls)
class_self.registered = True
return orig_init(self, *args, **kwargs)

if is_mappable_dataclass:
return kw_only_init(self, *args, **kwargs)
else:
return orig_init(self, *args, **kwargs)

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
Expand Down
14 changes: 9 additions & 5 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,6 @@ class ValidMappable:
get: int

with self.assertRaisesRegex(ValueError, 'dataclass fields are disallowed'):

@chex_dataclass(mappable_dataclass=True)
class InvalidMappable:
get: int
Expand Down Expand Up @@ -571,19 +570,24 @@ class Bar:
self.assertLen(jax.tree_util.tree_flatten(Bar())[0], 2)

@parameterized.named_parameters(
('mappable', True),
('not_mappable', False),
('mappable_frozen', True, True),
('not_mappable_frozen', False, True),
('mappable_not_frozen', True, False),
('not_mappable_not_frozen', False, False),
)
def test_generic_dataclass(self, mappable):
def test_generic_dataclass(self, mappable, frozen):
T = TypeVar('T')

@chex_dataclass(mappable_dataclass=mappable)
@chex_dataclass(mappable_dataclass=mappable, frozen=frozen)
class GenericDataclass(Generic[T]):
a: T # pytype: disable=invalid-annotation # enable-bare-annotations

obj = GenericDataclass(a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

obj = GenericDataclass[np.array](a=np.array([1.0, 1.0]))
asserts.assert_tree_all_close(obj.a, 1.0)

def test_mappable_eq_override(self):

@chex_dataclass(mappable_dataclass=True)
Expand Down

0 comments on commit 0f4f0a6

Please sign in to comment.