Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch LayerStackGetAttrKey to a custom dataclass type. #96

Merged
merged 1 commit into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion penzai/core/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,24 @@

from __future__ import annotations

import dataclasses
from typing import Any, Optional

import jax

PyTreeDef = jax.tree_util.PyTreeDef


@dataclasses.dataclass(frozen=True)
class CustomGetAttrKey:
"""Subclass-friendly variant of jax.tree_util.GetAttrKey."""

name: str

def __str__(self):
return f".{self.name}"


def tree_flatten_exactly_one_level(
tree: Any,
) -> Optional[tuple[list[tuple[Any, Any]], PyTreeDef]]:
Expand Down Expand Up @@ -66,7 +77,10 @@ def pretty_keystr(keypath: tuple[Any, ...], tree: Any) -> str:
parts = []
for key in keypath:
if isinstance(
key, jax.tree_util.GetAttrKey | jax.tree_util.FlattenedIndexKey
key,
jax.tree_util.GetAttrKey
| jax.tree_util.FlattenedIndexKey
| CustomGetAttrKey,
):
parts.extend(("/", type(tree).__name__))
split = tree_flatten_exactly_one_level(tree)
Expand Down
5 changes: 3 additions & 2 deletions penzai/nn/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
from __future__ import annotations

import collections
from collections.abc import Hashable
import copy
import dataclasses
import enum
from typing import Any, Callable, Hashable
from typing import Any, Callable

import jax
from penzai.core import named_axes
Expand All @@ -39,7 +40,7 @@ class LayerStackVarBehavior(enum.Enum):


@dataclasses.dataclass(frozen=True)
class LayerStackGetAttrKey(jax.tree_util.GetAttrKey):
class LayerStackGetAttrKey(pz_tree_util.CustomGetAttrKey):
"""GetAttrKey for LayerStack with extra metadata.

This allows us to identify whether a given PyTree leaf is contained inside a
Expand Down
14 changes: 13 additions & 1 deletion tests/nn/layer_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any
from absl.testing import absltest
import chex
import collections
import jax
from penzai import pz

Expand Down Expand Up @@ -155,7 +156,18 @@ def builder(init_base_rng, some_value):
unbound_layer, layer_vars = pz.unbind_variables(layer)
unbound_slot_layer, slot_layer_vars = pz.unbind_variables(slot_layer)

chex.assert_trees_all_equal(unbound_layer, unbound_slot_layer)
# Check as dictionaries to avoid limitations of chex:
unbound_layer_leaves, unbound_layer_treedef = (
jax.tree_util.tree_flatten_with_path(unbound_layer)
)
unbound_slot_layer_leaves, unbound_slot_layer_treedef = (
jax.tree_util.tree_flatten_with_path(unbound_slot_layer)
)
self.assertEqual(unbound_layer_treedef, unbound_slot_layer_treedef)
chex.assert_trees_all_equal(
collections.OrderedDict(unbound_layer_leaves),
collections.OrderedDict(unbound_slot_layer_leaves),
)

slot_layer_vars_dict = {var.label: var for var in slot_layer_vars}
for var in layer_vars:
Expand Down
Loading