From f90cba9b41afce36d6d646272fe33b9df99fd656 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Mon, 25 Nov 2024 16:17:35 +0800
Subject: [PATCH] fix `vector_grad` bugs (#68)

---
 brainunit/autograd/_vector_grad.py      |  6 +++---
 brainunit/autograd/_vector_grad_test.py | 15 +++++++++++++++
 2 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/brainunit/autograd/_vector_grad.py b/brainunit/autograd/_vector_grad.py
index c808071..a4d7c7d 100644
--- a/brainunit/autograd/_vector_grad.py
+++ b/brainunit/autograd/_vector_grad.py
@@ -57,9 +57,11 @@ def grad_fun(*args, **kwargs):
         assert len(leaves) == 1, 'The function must return a single array when unit_aware is True.'
         tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
         grads = vjp_fn(tangents)
+        if isinstance(argnums, int):
+            grads = grads[0]
         if unit_aware:
             args_to_grad = jax.tree.map(lambda i: args[i], argnums)
-            r_unit = get_unit(jax.tree.leaves(y, is_leaf=lambda x: isinstance(x, Quantity))[0])
+            r_unit = get_unit(y)
             grads = jax.tree.map(
                 lambda arg, grad: maybe_decimal(
                     Quantity(get_mantissa(grad), unit=r_unit / get_unit(arg))
@@ -68,8 +70,6 @@ def grad_fun(*args, **kwargs):
                 grads,
                 is_leaf=lambda x: isinstance(x, Quantity)
             )
-        if isinstance(argnums, int):
-            grads = grads[0]
         if has_aux:
             return (grads, y, aux) if return_value else (grads, aux)
         else:
diff --git a/brainunit/autograd/_vector_grad_test.py b/brainunit/autograd/_vector_grad_test.py
index 76790b4..e83f915 100644
--- a/brainunit/autograd/_vector_grad_test.py
+++ b/brainunit/autograd/_vector_grad_test.py
@@ -39,6 +39,21 @@ def simple_function(x):
         assert u.math.allclose(grad, jnp.array([6.0, 8.0]) * unit)
 
 
+def test_vector_grad_simple2():
+    def simple_function(x):
+        return x ** 3
+
+    x = jnp.array([3.0, 4.0])
+    for unit in [None, u.ms, u.mvolt]:
+        vector_grad_fn = u.autograd.vector_grad(simple_function)
+        if unit is None:
+            grad = vector_grad_fn(x)
+            assert jnp.allclose(grad, 3 * x ** 2)
+        else:
+            grad = vector_grad_fn(x * unit)
+            assert u.math.allclose(grad, 3 * (x * unit) ** 2)
+
+
 def test_vector_grad_multiple_args():
     def multi_arg_function(x, y):
         return x * y
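
For context, the patch makes two changes: it unwraps the VJP result with `grads = grads[0]` *before* the unit-aware mapping (so that when `argnums` is a single int, the gradient tree matches the single-argument tree instead of a one-element tuple), and it reads the result unit directly via `get_unit(y)`. Below is a minimal usage sketch of the fixed behavior, mirroring the new test above; it assumes only the public API already exercised in this patch (`u.autograd.vector_grad`, `u.ms`, `u.math.allclose`):

import jax.numpy as jnp
import brainunit as u

def f(x):
    return x ** 3

x = jnp.array([3.0, 4.0])

# Unitless input: an element-wise gradient, d/dx x**3 = 3 x**2,
# returned as a plain JAX array.
grad = u.autograd.vector_grad(f)(x)
assert jnp.allclose(grad, 3 * x ** 2)

# Unit-carrying input: the gradient carries unit(y) / unit(x),
# here ms**3 / ms = ms**2, matching 3 * (x * u.ms) ** 2.
grad_q = u.autograd.vector_grad(f)(x * u.ms)
assert u.math.allclose(grad_q, 3 * (x * u.ms) ** 2)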