Skip to content

Commit

Permalink
fix vector_grad bugs (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 authored Nov 25, 2024
1 parent a80a683 commit f90cba9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 3 deletions.
6 changes: 3 additions & 3 deletions brainunit/autograd/_vector_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,11 @@ def grad_fun(*args, **kwargs):
assert len(leaves) == 1, 'The function must return a single array when unit_aware is True.'
tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves])
grads = vjp_fn(tangents)
if isinstance(argnums, int):
grads = grads[0]
if unit_aware:
args_to_grad = jax.tree.map(lambda i: args[i], argnums)
r_unit = get_unit(jax.tree.leaves(y, is_leaf=lambda x: isinstance(x, Quantity))[0])
r_unit = get_unit(y)
grads = jax.tree.map(
lambda arg, grad: maybe_decimal(
Quantity(get_mantissa(grad), unit=r_unit / get_unit(arg))
Expand All @@ -68,8 +70,6 @@ def grad_fun(*args, **kwargs):
grads,
is_leaf=lambda x: isinstance(x, Quantity)
)
if isinstance(argnums, int):
grads = grads[0]
if has_aux:
return (grads, y, aux) if return_value else (grads, aux)
else:
Expand Down
15 changes: 15 additions & 0 deletions brainunit/autograd/_vector_grad_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ def simple_function(x):
assert u.math.allclose(grad, jnp.array([6.0, 8.0]) * unit)


def test_vector_grad_simple2():
def simple_function(x):
return x ** 3

x = jnp.array([3.0, 4.0])
for unit in [None, u.ms, u.mvolt]:
vector_grad_fn = u.autograd.vector_grad(simple_function)
if unit is None:
grad = vector_grad_fn(x)
assert jnp.allclose(grad, 3 * x ** 2)
else:
grad = vector_grad_fn(x * unit)
assert u.math.allclose(grad, 3 * (x * unit) ** 2)


def test_vector_grad_multiple_args():
def multi_arg_function(x, y):
return x * y
Expand Down

0 comments on commit f90cba9

Please sign in to comment.