diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index 5adf3556..4dbc0c76 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -156,6 +156,7 @@
 from .variablelib import BatchStat as BatchStat
 from .variablelib import Cache as Cache
 from .variablelib import Intermediate as Intermediate
+from .variablelib import Perturbation as Perturbation
 from .variablelib import Variable as Variable
 from .variablelib import VariableState as VariableState
 from .variablelib import VariableMetadata as VariableMetadata
diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py
index b3392c86..bcbd1c25 100644
--- a/flax/nnx/bridge/variables.py
+++ b/flax/nnx/bridge/variables.py
@@ -140,7 +140,7 @@ def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
       nnx_attrs[attr_name] = _recursive_merge(nnx_attrs[attr_name], value)
     else:
       nnx_attrs[attr_name] = value  # it's a variable on this layer
-  return nnx_attrs
+  return dict(nnx_attrs)
 
 
 def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
diff --git a/flax/nnx/module.py b/flax/nnx/module.py
index b07efa77..39ee7e3b 100644
--- a/flax/nnx/module.py
+++ b/flax/nnx/module.py
@@ -17,6 +17,8 @@
 import typing as tp
 from functools import partial
 
+import jax
+import jax.numpy as jnp
 import jax.tree_util as jtu
 
 from flax.nnx import (
@@ -183,6 +185,78 @@ def sow(
     reduced_value = reduce_fn(init_fn(), value)
     setattr(self, name, variable_type(reduced_value))
 
+  def perturb(
+    self,
+    name: str,
+    value: tp.Any,
+    variable_type: tp.Type[variableslib.Variable[tp.Any]] = variableslib.Perturbation,
+  ):
+    """Add a zero-valued variable ("perturbation") to the intermediate value.
+
+    The gradient of ``value`` will be the same as the gradient of this
+    perturbation variable. Therefore, if you define your loss function with
+    both params and perturbations as standalone arguments, you can get the
+    intermediate gradients of ``value`` by running ``jax.grad`` on the
+    perturbation variable.
+
+    Since the shape of the perturbation value depends on the shape of the input,
+    a perturbation variable is only created after you run a sample input through
+    the model once.
+
+    .. note::
+      This creates extra dummy variables of the same size as ``value`` and thus
+      occupies more memory. Use it only to debug gradients in training.
+
+    Example usage::
+
+      >>> from flax import nnx
+      >>> import jax.numpy as jnp
+
+      >>> class Model(nnx.Module):
+      ...   def __init__(self, rngs):
+      ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
+      ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
+      ...   def __call__(self, x):
+      ...     x = self.linear1(x)
+      ...     x = self.perturb('xgrad', x)
+      ...     x = self.linear2(x)
+      ...     return x
+
+      >>> x = jnp.ones((1, 2))
+      >>> y = jnp.ones((1, 4))
+      >>> model = Model(rngs=nnx.Rngs(0))
+      >>> assert not hasattr(model, 'xgrad')  # perturbation requires a sample input run
+      >>> _ = model(x)
+      >>> assert model.xgrad.value.shape == (1, 3)  # same as the intermediate value
+
+      >>> # Take gradients on the Param and Perturbation variables
+      >>> @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Any(nnx.Param, nnx.Perturbation)))
+      ... def grad_loss(model, inputs, targets):
+      ...   preds = model(inputs)
+      ...   return jnp.square(preds - targets).mean()
+
+      >>> intm_grads = grad_loss(model, x, y)
+      >>> # `intm_grads.xgrad.value` is the intermediate gradient
+      >>> assert not jnp.array_equal(intm_grads.xgrad.value, jnp.zeros((1, 3)))
+
+    Args:
+      name: A string denoting the ``Module`` attribute name for the
+        perturbation value.
+      value: The value to take the intermediate gradient of.
+      variable_type: The :class:`Variable` type for the stored perturbation.
+        Defaults to :class:`nnx.Perturbation`.
+    """
+    if not hasattr(self, name):
+      zeros = jax.tree.map(jnp.zeros_like, value)
+      setattr(self, name, variable_type(zeros))
+    old_value = getattr(self, name)
+    if not isinstance(old_value, variable_type):
+      raise ValueError(
+        f"Expected '{name}' to be of type '{variable_type.__name__}', "
+        f"got '{type(old_value).__name__}'"
+      )
+    return old_value.value + value
+
   def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]:
     """Recursively iterates over all nested :class:`Module`'s of the current Module,
     including the current Module.
diff --git a/flax/nnx/variablelib.py b/flax/nnx/variablelib.py
index 83da9c7a..f1cca3f8 100644
--- a/flax/nnx/variablelib.py
+++ b/flax/nnx/variablelib.py
@@ -747,6 +747,38 @@ class Intermediate(Variable[A]):
   pass
 
 
+class Perturbation(Intermediate[A]):
+  """:class:`Variable` type that is typically used for
+  :func:`Module.perturb`::
+
+    >>> from flax import nnx
+    >>> import jax, jax.numpy as jnp
+
+    >>> class Model(nnx.Module):
+    ...   def __init__(self, rngs):
+    ...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
+    ...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
+    ...   def __call__(self, x):
+    ...     x = self.linear1(x)
+    ...     x = self.perturb('i', x)
+    ...     x = self.linear2(x)
+    ...     return x
+    >>> model = Model(rngs=nnx.Rngs(0))
+
+    >>> x = jnp.ones((1, 2))
+    >>> y = model(x)
+    >>> jax.tree.map(jnp.shape, nnx.state(model, nnx.Perturbation))
+    State({
+      'i': VariableState(
+        type=Perturbation,
+        value=(1, 3)
+      )
+    })
+  """
+
+  pass
+
+
 class VariableState(tp.Generic[A], reprlib.Representable):
   __slots__ = ('type', 'value', '_var_metadata')
   type: type[Variable[A]]
@@ -1011,4 +1043,5 @@ def register_variable_name_type_pair(name, typ, overwrite = False):
 register_variable_name_type_pair('params', Param)
 register_variable_name_type_pair('batch_stats', BatchStat)
 register_variable_name_type_pair('cache', Cache)
-register_variable_name_type_pair('intermediates', Intermediate)
\ No newline at end of file
+register_variable_name_type_pair('intermediates', Intermediate)
+register_variable_name_type_pair('perturbations', Perturbation)
\ No newline at end of file
diff --git a/tests/nnx/module_test.py b/tests/nnx/module_test.py
index 64928f46..6ecf1e80 100644
--- a/tests/nnx/module_test.py
+++ b/tests/nnx/module_test.py
@@ -322,6 +322,43 @@ def __call__(self, x):
     with self.assertRaisesRegex(ValueError, 'to be of type'):
       m(2)
 
+  def test_perturb_basic(self):
+    class Foo(nnx.Module):
+      def __init__(self, rngs):
+        self.linear = nnx.Linear(10, 10, rngs=rngs)
+
+      def __call__(self, x):
+        x = self.linear(x)
+        x = self.perturb('before_multiply', x)
+        x = 4 * x
+        x = self.perturb('after_multiply', x)
+        return x
+
+    model = Foo(rngs=nnx.Rngs(0))
+    # Perturbations are not created at init time; they require a sample input.
+    self.assertFalse(hasattr(model, 'before_multiply'))
+    self.assertFalse(hasattr(model, 'after_multiply'))
+
+    x = jax.random.uniform(jax.random.key(1), shape=(10,))
+    y = jax.random.uniform(jax.random.key(2), shape=(10,))
+    model(x)
+    np.testing.assert_array_equal(model.before_multiply, jnp.zeros_like(x))
+    np.testing.assert_array_equal(model.after_multiply, jnp.zeros_like(x))
+
+    take_gradient_filter = nnx.Any(nnx.Param, nnx.Perturbation)
+    @nnx.grad(argnums=nnx.DiffState(argnum=0, filter=take_gradient_filter))
+    def grad_loss(model, inputs, targets):
+      preds = model(inputs)
+      return jnp.square(preds - targets).mean()
+    intm_grads = grad_loss(model, x, y)
+
+    # Gradients should not be zero.
+    self.assertFalse(jnp.array_equal(
+        intm_grads.before_multiply.value, jnp.zeros_like(x)))
+    # The activation is multiplied by 4, so the reverse-mode gradient is scaled by 4 as well.
+    np.testing.assert_allclose(intm_grads.after_multiply.value * 4,
+                               intm_grads.before_multiply.value)
+
   def test_update_static_state_submodules(self):
     class Bar(nnx.Module):
       def __init__(self) -> None:
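
# --- Illustrative usage sketch (not part of the patch above) ---
# A minimal, hedged sketch of how the `Module.perturb` / `nnx.Perturbation` API
# introduced by this diff could be used end to end: read an intermediate
# gradient, then discard the dummy perturbation variables to reclaim the extra
# memory mentioned in the docstring note. The final `nnx.pop` step is an
# assumption: since `Perturbation` subclasses `Intermediate`, popping by that
# filter should work the same way it does for sown intermediates, but the diff
# itself does not exercise it.
from flax import nnx
import jax.numpy as jnp


class Model(nnx.Module):
  def __init__(self, rngs):
    self.linear1 = nnx.Linear(2, 3, rngs=rngs)
    self.linear2 = nnx.Linear(3, 4, rngs=rngs)

  def __call__(self, x):
    x = self.linear1(x)
    x = self.perturb('xgrad', x)  # zero-valued Perturbation created on first call
    x = self.linear2(x)
    return x


model = Model(rngs=nnx.Rngs(0))
x, y = jnp.ones((1, 2)), jnp.ones((1, 4))
_ = model(x)  # creates `model.xgrad` with the intermediate shape (1, 3)


@nnx.grad(argnums=nnx.DiffState(argnum=0, filter=nnx.Any(nnx.Param, nnx.Perturbation)))
def grad_loss(model, inputs, targets):
  return jnp.square(model(inputs) - targets).mean()


grads = grad_loss(model, x, y)
print(grads.xgrad.value)  # gradient w.r.t. the intermediate activation

# Assumed cleanup step: remove the dummy perturbation variables after debugging.
_ = nnx.pop(model, nnx.Perturbation)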