Skip to content

Commit

Permalink
Tweaks to core variable API.
Browse files Browse the repository at this point in the history
- Changes "frozen" to "freeze" in `pz.unbind_variables`
- Adds an "unfreeze_as_copy" helper argument to `pz.bind_variables`
- Explicit error when returning a variable from `pz.variable_jit` (which isn't supported)

PiperOrigin-RevId: 647186170
  • Loading branch information
danieldjohnson authored and Penzai Developers committed Jun 27, 2024
1 parent f78dc22 commit 84bde94
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 21 deletions.
10 changes: 5 additions & 5 deletions docs/guides/howto_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ To extract model parameters, you can use `pz.unbind_params`, which extracts and
unbound_model, params = pz.unbind_params(model)
```

After extracting parameters, you may also want to freeze them, which produces an immutable snapshot of the current value of each parameter (as an instance of `ParameterValue`). You can do this by calling `.freeze()` on each parameter, by using `pz.freeze_params` to freeze all parameters in a collection, or by using `pz.unbind_params(model, frozen=True)`. Frozen parameters are ordinary JAX PyTrees, making them safe to use across JAX transformation boundaries.
After extracting parameters, you may also want to freeze them, which produces an immutable snapshot of the current value of each parameter (as an instance of `ParameterValue`). You can do this by calling `.freeze()` on each parameter, by using `pz.freeze_params` to freeze all parameters in a collection, or by using `pz.unbind_params(model, freeze=True)`. Frozen parameters are ordinary JAX PyTrees, making them safe to use across JAX transformation boundaries.

Both mutable `Parameter` instances and frozen `ParameterValue` instances can be substituted back into a model with `ParameterSlot`s `pz.bind_variables`. A common pattern is to unbind and freeze `Parameter`s before a JAX transformation, and then re-bind their frozen values inside the function being transformed.

Expand All @@ -283,7 +283,7 @@ def my_loss(params, unbound_model):
loss = # (... compute the loss ...)
return loss

unbound_model, frozen_params = pz.unbind_params(model, frozen=True)
unbound_model, frozen_params = pz.unbind_params(model, freeze=True)
grads = jax.grad(my_loss, argnums=0)(frozen_params, unbound_model)
```

Expand All @@ -295,7 +295,7 @@ def my_func(params, unbound_model):
rebound_model = pz.bind_variables(unbound_model, params)
return rebound_model(...) # call it with some arguments

unbound_model, frozen_params = pz.unbind_params(model, frozen=True)
unbound_model, frozen_params = pz.unbind_params(model, freeze=True)

# Build your input perturbations somehow
perturbations = jax.tree_util.tree_map(some_func, frozen_params)
Expand All @@ -308,10 +308,10 @@ Some Penzai layers keep track of mutable `pz.StateVariable` instances and update

Outside of JAX transformations, you can usually just mutate state variables normally. However, running stateful operations inside JAX transformations can require some care. Additionally, it's sometimes useful to take a snapshot of the state of all variables in a model.

When working with a model that uses state variables, you can unbind the state variables using `pz.unbind_state_vars`, and optionally freeze them using `pz.freeze_state_vars` (or unbind with `frozen=True`), similar to the corresponding methods for `Parameter`s. This allows you to extract an immutable view of the model state that is safe to manipulate in JAX, e.g. via
When working with a model that uses state variables, you can unbind the state variables using `pz.unbind_state_vars`, and optionally freeze them using `pz.freeze_state_vars` (or unbind with `freeze=True`), similar to the corresponding methods for `Parameter`s. This allows you to extract an immutable view of the model state that is safe to manipulate in JAX, e.g. via

```
stateless_model, frozen_states = pz.unbind_state_vars(model, frozen=True)
stateless_model, frozen_states = pz.unbind_state_vars(model, freeze=True)
```

Every subclass of `Layer` implements the method `stateless_call`, which takes frozen state variables as input and returns updated state variables as output:
Expand Down
2 changes: 1 addition & 1 deletion notebooks/v2_how_to_think_in_penzai.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@
"\n",
"# Unbind and freeze state vars:\n",
"unbound_frozen_model, state_vars = pz.unbind_state_vars(\n",
" frozen_param_model, frozen=True\n",
" frozen_param_model, freeze=True\n",
")\n",
"state_var_values = pz.freeze_state_vars(state_vars)\n",
"\n",
Expand Down
51 changes: 38 additions & 13 deletions penzai/experimental/v2/core/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class AbstractVariableSlot(struct.Struct, abc.ABC):
def unbind_variables(
tree: Any,
predicate: Callable[[AbstractVariable], bool] | None = None,
frozen: Literal[False] = False,
freeze: Literal[False] = False,
) -> tuple[Any, tuple[AbstractVariable, ...]]:
...

Expand All @@ -168,15 +168,15 @@ def unbind_variables(
tree: Any,
predicate: Callable[[AbstractVariable], bool] | None = None,
*,
frozen: Literal[True],
freeze: Literal[True],
) -> tuple[Any, tuple[AbstractVariableValue, ...]]:
...


def unbind_variables(
tree: Any,
predicate: Callable[[AbstractVariable], bool] | None = None,
frozen: bool = False,
freeze: bool = False,
) -> tuple[Any, tuple[AbstractVariable | AbstractVariableValue, ...]]:
"""Unbinds variables from a pytree, inserting variable slots in their place.
Expand All @@ -194,7 +194,7 @@ def unbind_variables(
raised.
predicate: A function that returns True for variables that should be
extracted. If None, all variables will be extracted.
frozen: Whether to return frozen variables instead of mutable variables.
freeze: Whether to return frozen variables instead of mutable variables.
Returns:
A tuple ``(tree_with_slots, variables)``, where ``tree_with_slots`` is
Expand Down Expand Up @@ -248,7 +248,7 @@ def unbind_variables(
else:
new_leaves.append(leaf)

if frozen:
if freeze:
extracted = tuple(var.freeze() for var in variable_dict.values())
else:
extracted = tuple(variable_dict.values())
Expand All @@ -260,6 +260,7 @@ def bind_variables(
tree: Any,
variables: Iterable[AbstractVariable | AbstractVariableValue],
allow_unused: bool = False,
unfreeze_as_copy: bool = False,
) -> Any:
"""Binds variables (mutable or frozen) into the variable slots in a pytree.
Expand All @@ -271,6 +272,9 @@ def bind_variables(
variables: The collection of variables to insert.
allow_unused: Whether to ignore variables that do not have any matching slot
(in which case they will not be inserted).
unfreeze_as_copy: Whether to unfreeze variable values before inserting them,
producing a new mutable copy of each input variable. If True, all input
variables must be `AbstractVariableValue`s.
Returns:
A copy of ``tree`` with variables re-inserted.
Expand All @@ -279,6 +283,18 @@ def bind_variables(
tree, is_leaf=lambda l: isinstance(l, AbstractVariableSlot)
)

if unfreeze_as_copy:
orig_variables = variables
variables = []
for var in orig_variables:
if isinstance(var, AbstractVariableValue):
variables.append(var.unfreeze_as_copy())
else:
raise ValueError(
"unfreeze_as_copy=True is only allowed if all variables are"
" variable values (e.g. ParameterValue or StateVariableValue)."
)

substitution = {}
for var in variables:
var_slot = var.get_slot()
Expand Down Expand Up @@ -390,6 +406,15 @@ def inner_fun(*args, **kwargs):
mut_vars = [var.unfreeze_as_copy() for var in frozen_variables]
(rebound_args, rebound_kwargs) = bind_variables((args, kwargs), mut_vars)
result = fun(*rebound_args, **rebound_kwargs)
_, bad_vars = unbind_variables(result)
if bad_vars:
raise ValueError(
"Returning a variable from a function transformed by pz.variable_jit"
" is not allowed. To create new variables under jax.jit, you should"
" instead return `pz.unbind_variables(..., frozen=True)`, then"
" rebuild the new variables after with `pz.bind_variables(...,"
f" unfreeze_as_copy=True)`.\nFound variables: {bad_vars}"
)
return result, [var.freeze() for var in mut_vars]

inner_fun.__signature__ = new_sig
Expand Down Expand Up @@ -782,7 +807,7 @@ def _type_filtered_predicate(
def unbind_params(
tree: Any,
predicate: Callable[[Parameter], bool] | None = None,
frozen: Literal[False] = False,
freeze: Literal[False] = False,
) -> tuple[Any, tuple[Parameter, ...]]:
...

Expand All @@ -792,21 +817,21 @@ def unbind_params(
tree: Any,
predicate: Callable[[Parameter], bool] | None = None,
*,
frozen: Literal[True],
freeze: Literal[True],
) -> tuple[Any, tuple[ParameterValue, ...]]:
...


def unbind_params(
tree: Any,
predicate: Callable[[Parameter], bool] | None = None,
frozen: bool = False,
freeze: bool = False,
) -> tuple[Any, tuple[Parameter | ParameterValue, ...]]:
r"""Version of `unbind_variables` that only extracts `Parameter`\ s."""
return unbind_variables( # type: ignore
tree,
predicate=_type_filtered_predicate(predicate, Parameter),
frozen=frozen,
freeze=freeze,
)


Expand All @@ -824,7 +849,7 @@ def freeze_params(
def unbind_state_vars(
tree: Any,
predicate: Callable[[StateVariable], bool] | None = None,
frozen: Literal[False] = False,
freeze: Literal[False] = False,
) -> tuple[Any, tuple[StateVariable, ...]]:
...

Expand All @@ -834,21 +859,21 @@ def unbind_state_vars(
tree: Any,
predicate: Callable[[StateVariable], bool] | None = None,
*,
frozen: Literal[True],
freeze: Literal[True],
) -> tuple[Any, tuple[StateVariableValue, ...]]:
...


def unbind_state_vars(
tree: Any,
predicate: Callable[[StateVariable], bool] | None = None,
frozen: bool = False,
freeze: bool = False,
) -> tuple[Any, tuple[StateVariable | StateVariableValue, ...]]:
r"""Version of `unbind_variables` that only extracts `StateVariable`\ s."""
return unbind_variables( # type: ignore
tree,
predicate=_type_filtered_predicate(predicate, StateVariable),
frozen=frozen,
freeze=freeze,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,7 @@
"""Concrete transformer variants."""

from . import gemma
from . import gpt_neox
from . import llama
from . import llamalike_common
from . import mistral
2 changes: 1 addition & 1 deletion penzai/experimental/v2/nn/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def go(rng):
**builder_kwargs,
)
unbound_sublayer, var_values = variables.unbind_variables(
sublayer, frozen=True
sublayer, freeze=True
)
if any(
not isinstance(leaf, named_axes.NamedArrayBase)
Expand Down
39 changes: 38 additions & 1 deletion tests/experimental/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def test_unbind_variables(self):

with self.subTest("unbind_frozen"):
thing_with_slots, vars_list = variables.unbind_variables(
thing_with_vars, frozen=True
thing_with_vars, freeze=True
)
self.assertEqual(
thing_with_slots,
Expand Down Expand Up @@ -168,6 +168,33 @@ def test_bind_variables(self):
},
)

def test_bind_and_unfreeze(self):
var_1 = variables.StateVariableValue(value=1, label="var_1")
var_2 = variables.ParameterValue(value=2, label="var_2")
thing_with_slots = {
"var_1": variables.StateVariableSlot("var_1"),
"var_2": variables.ParameterSlot("var_2"),
"inner": [
variables.StateVariableSlot("var_1"),
variables.ParameterSlot("var_2"),
],
"something_else": "something else",
}
unfrozen = variables.bind_variables(
thing_with_slots, [var_1, var_2], unfreeze_as_copy=True
)
self.assertIsInstance(unfrozen["var_1"], variables.StateVariable)
self.assertIsInstance(unfrozen["var_2"], variables.Parameter)
self.assertEqual(
unfrozen,
{
"var_1": unfrozen["var_1"],
"var_2": unfrozen["var_2"],
"inner": [unfrozen["var_1"], unfrozen["var_2"]],
"something_else": "something else",
},
)

def test_variable_freeze_unfreeze(self):
var_1 = variables.StateVariable(
value=1.0, label="var_1", metadata={"foo": "bar"}
Expand Down Expand Up @@ -282,6 +309,16 @@ def my_var_fun(thing_with_vars, increment, something_else):
self.assertEqual(thing_with_vars["var_2"].value, 1002)
self.assertEqual(result, 1103)

def test_variable_jit_disallows_returning_vars(self):

def bad(x):
return [1, 2, variables.Parameter(label="foo", value=x)]

with self.assertRaisesWithPredicateMatch(
ValueError, lambda exc: "Returning a variable" in str(exc)
):
variables.variable_jit(bad)(10)


if __name__ == "__main__":
absltest.main()

0 comments on commit 84bde94

Please sign in to comment.