diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index c4e8426..1d3c8f0 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -279,10 +279,10 @@ def _flatten_with_path(dcls): path = [] keys = [] for k, v in sorted(dcls.__dict__.items()): + keys.append(k) # generate same aux data as flatten without path k = jax.tree_util.GetAttrKey(k) path.append((k, v)) - keys.append(k) - return path, keys + return path, tuple(keys) @functools.cache