Skip to content

Commit

Permalink
fix autograd bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 13, 2024
1 parent d1a4afb commit 79efb61
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
9 changes: 5 additions & 4 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,17 @@ def __call__(self, *args, **kwargs):
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True

# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)

# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

return self._return(rets)


Expand Down
11 changes: 11 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ def call(a, b, c):
assert aux[1] == bm.exp(0.1)


def test_grad_jit(self):
def call(a, b, c): return bm.sum(a + b + c)

bm.random.seed(1)
a = bm.ones(10)
b = bm.random.randn(10)
c = bm.random.uniform(size=10)
f_grad = bm.jit(bm.grad(call))
assert (f_grad(a, b, c) == 1.).all()


class TestObjectFuncGrad(unittest.TestCase):
def test_grad_ob1(self):
class Test(bp.BrainPyObject):
Expand Down

0 comments on commit 79efb61

Please sign in to comment.