Add nnx.Module.perturb #4515

Merged 1 commit on Jan 31, 2025
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
@@ -156,6 +156,7 @@
from .variablelib import BatchStat as BatchStat
from .variablelib import Cache as Cache
from .variablelib import Intermediate as Intermediate
from .variablelib import Perturbation as Perturbation
from .variablelib import Variable as Variable
from .variablelib import VariableState as VariableState
from .variablelib import VariableMetadata as VariableMetadata
2 changes: 1 addition & 1 deletion flax/nnx/bridge/variables.py
@@ -140,7 +140,7 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
nnx_attrs[attr_name] = _recursive_merge(nnx_attrs[attr_name], value)
else:
nnx_attrs[attr_name] = value # it's a variable on this layer
return nnx_attrs
return dict(nnx_attrs)


def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
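A note on the ``return dict(nnx_attrs)`` change above: assuming ``nnx_attrs`` is built as a ``collections.defaultdict(dict)`` upstream (which the ``nnx_attrs[attr_name]`` lookups suggest), converting it to a plain ``dict`` before returning keeps the defaultdict's silent auto-insertion from leaking to callers. A minimal sketch of the difference, using placeholder values:

from collections import defaultdict

nnx_attrs = defaultdict(dict)
nnx_attrs['linear']['kernel'] = 'some variable'  # placeholder value

plain = dict(nnx_attrs)      # what the patched function now returns

_ = nnx_attrs['typo']        # raw defaultdict: silently inserts an empty {} entry
assert 'typo' in nnx_attrs   # surprising side effect for a caller

try:
    plain['typo']            # plain dict: missing keys fail loudly
except KeyError:
    pass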
74 changes: 74 additions & 0 deletions flax/nnx/module.py
@@ -17,6 +17,8 @@
import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from flax.nnx import (
@@ -183,6 +185,78 @@ def sow(
reduced_value = reduce_fn(init_fn(), value)
setattr(self, name, variable_type(reduced_value))

def perturb(
self,
name: str,
value: tp.Any,
variable_type: tp.Type[variableslib.Variable[tp.Any]] = variableslib.Perturbation,
):
"""Add an zero-value variable ("perturbation") to the intermediate value.

The gradient of ``value`` is the same as the gradient of this
perturbation variable. Therefore, if you define your loss function with
both params and perturbations as standalone arguments, you can get the
intermediate gradients of ``value`` by running ``jax.grad`` on the
perturbation variable.

Since the shape of the perturbation value depends on the shape of the input,
a perturbation variable is only created after you run a sample input through
the model once.

.. note::
This creates extra dummy variables of the same size as ``value`` and thus
occupies more memory. Use it only to debug gradients during training.

Example usage::

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... x = self.perturb('xgrad', x)
... x = self.linear2(x)
... return x

>>> x = jnp.ones((1, 2))
>>> y = jnp.ones((1, 4))
>>> model = Model(rngs=nnx.Rngs(0))
>>> assert not hasattr(model, 'xgrad') # perturbation requires a sample input run
>>> _ = model(x)
>>> assert model.xgrad.value.shape == (1, 3) # same as the intermediate value

>>> # Take gradients on the Param and Perturbation variables
>>> @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Any(nnx.Param, nnx.Perturbation)))
... def grad_loss(model, inputs, targets):
... preds = model(inputs)
... return jnp.square(preds - targets).mean()

>>> intm_grads = grad_loss(model, x, y)
>>> # `intm_grads.xgrad.value` is the intermediate gradient
>>> assert not jnp.array_equal(intm_grads.xgrad.value, jnp.zeros((1, 3)))

Args:
name: A string denoting the ``Module`` attribute name for the
perturbation value.
value: The value to take the intermediate gradient of.
variable_type: The :class:`Variable` type for the stored perturbation.
Defaults to :class:`nnx.Perturbation`.
"""
if not hasattr(self, name):
zeros = jax.tree.map(jnp.zeros_like, value)
setattr(self, name, variable_type(zeros))
old_value = getattr(self, name)
if not isinstance(old_value, variable_type):
raise ValueError(
f"Expected '{name}' to be of type '{variable_type.__name__}', "
f"got '{type(old_value).__name__}'"
)
return old_value.value + value

def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
"""Recursively iterates over all nested :class:`Module`'s of the current Module, including
the current Module.
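Why the gradient with respect to the perturbation equals the gradient of the intermediate value: ``perturb`` returns ``old_value.value + value``, i.e. it adds a zero-valued variable to the intermediate, which leaves the forward pass unchanged but gives the added term the same cotangent as the intermediate. A minimal plain-JAX sketch of the same trick (illustrative names, not part of this PR):

import jax
import jax.numpy as jnp

def loss(w, p, x):
    h = x @ w          # intermediate activation
    h = h + p          # add a zero "perturbation"; the forward value is unchanged
    return jnp.sum(h ** 2)

w = jnp.ones((2, 3))
x = jnp.ones((1, 2))
p = jnp.zeros((1, 3))

# d loss / d p equals d loss / d h, since p only enters through h + p.
g_p = jax.grad(loss, argnums=1)(w, p, x)
g_h = 2 * (x @ w)      # analytic gradient of sum(h**2) w.r.t. h
assert jnp.allclose(g_p, g_h)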
35 changes: 34 additions & 1 deletion flax/nnx/variablelib.py
@@ -747,6 +747,38 @@ class Intermediate(Variable[A]):
pass


class Perturbation(Intermediate[A]):
""":class:`Variable` type that is typically used for
:func:`Module.perturb`::

>>> from flax import nnx
>>> import jax, jax.numpy as jnp

>>> class Model(nnx.Module):
... def __init__(self, rngs):
... self.linear1 = nnx.Linear(2, 3, rngs=rngs)
... self.linear2 = nnx.Linear(3, 4, rngs=rngs)
... def __call__(self, x):
... x = self.linear1(x)
... x = self.perturb('i', x)
... x = self.linear2(x)
... return x
>>> model = Model(rngs=nnx.Rngs(0))

>>> x = jnp.ones((1, 2))
>>> y = model(x)
>>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
State({
'i': VariableState(
type=Perturbation,
value=(1, 3)
)
})
"""

pass


class VariableState(tp.Generic[A], reprlib.Representable):
__slots__ = ('type', 'value', '_var_metadata')
type: type[Variable[A]]
@@ -1011,4 +1043,5 @@ def register_variable_name_type_pair(name, typ, overwrite = False):
register_variable_name_type_pair('params', Param)
register_variable_name_type_pair('batch_stats', BatchStat)
register_variable_name_type_pair('cache', Cache)
register_variable_name_type_pair('intermediates', Intermediate)
register_variable_name_type_pair('intermediates', Intermediate)
register_variable_name_type_pair('perturbations', Perturbation)
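The new registration pairs the ``'perturbations'`` collection name with the ``Perturbation`` type, mirroring the entries above. For illustration (hypothetical names, not part of this PR), the same helper can register a custom ``Variable`` subclass under its own collection name:

from flax import nnx
from flax.nnx import variablelib

class MyStat(nnx.Variable):  # hypothetical custom variable type
    pass

variablelib.register_variable_name_type_pair('my_stats', MyStat)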
37 changes: 37 additions & 0 deletions tests/nnx/module_test.py
@@ -322,6 +322,43 @@ def __call__(self, x):
with self.assertRaisesRegex(ValueError, 'to be of type'):
m(2)

def test_perturb_basic(self):
class Foo(nnx.Module):
def __init__(self, rngs):
self.linear = nnx.Linear(10, 10, rngs=rngs)

def __call__(self, x):
x = self.linear(x)
x = self.perturb('before_multiply', x)
x = 4 * x
x = self.perturb('after_multiply', x)
return x

model = Foo(rngs=nnx.Rngs(0))
# Perturbations are not created at init time; they need a sample input.
self.assertFalse(hasattr(model, 'before_multiply'))
self.assertFalse(hasattr(model, 'after_multiply'))

x = jax.random.uniform(jax.random.key(1), shape=(10,))
y = jax.random.uniform(jax.random.key(2), shape=(10,))
model(x)
np.testing.assert_array_equal(model.before_multiply, jnp.zeros_like(x))
np.testing.assert_array_equal(model.after_multiply, jnp.zeros_like(x))

take_gradient_filter = nnx.Any(nnx.Param, nnx.Perturbation)
@nnx.grad(argnums=nnx.DiffState(argnum=0, filter=take_gradient_filter))
def grad_loss(model, inputs, targets):
preds = model(inputs)
return jnp.square(preds - targets).mean()
intm_grads = grad_loss(model, x, y)

# Gradient should not be zero
self.assertFalse(jnp.array_equal(
intm_grads.before_multiply.value, jnp.zeros_like(x)))
# The forward pass multiplies the activation by 4, so the gradient before the multiply is 4x the gradient after it.
np.testing.assert_allclose(intm_grads.after_multiply.value * 4,
intm_grads.before_multiply.value)

def test_update_static_state_submodules(self):
class Bar(nnx.Module):
def __init__(self) -> None: