diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index 2e5e103c..ee0511e8 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -213,16 +213,17 @@ def __call__(self, *args, **kwargs): self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) self._eval_dyn_vars = True - # if not the outermost transformation - if not stack.is_first_stack(): - return self._return(rets) - rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients self._dyn_vars.dict_data(), # dynamical variables *args, **kwargs ) + + # if not the outermost transformation + if not stack.is_first_stack(): + return self._return(rets) + return self._return(rets) diff --git a/brainpy/_src/math/object_transform/tests/test_autograd.py b/brainpy/_src/math/object_transform/tests/test_autograd.py index 1cd7c7cd..bb4adf1d 100644 --- a/brainpy/_src/math/object_transform/tests/test_autograd.py +++ b/brainpy/_src/math/object_transform/tests/test_autograd.py @@ -86,6 +86,17 @@ def call(a, b, c): assert aux[1] == bm.exp(0.1) + def test_grad_jit(self): + def call(a, b, c): return bm.sum(a + b + c) + + bm.random.seed(1) + a = bm.ones(10) + b = bm.random.randn(10) + c = bm.random.uniform(size=10) + f_grad = bm.jit(bm.grad(call)) + assert (f_grad(a, b, c) == 1.).all() + + class TestObjectFuncGrad(unittest.TestCase): def test_grad_ob1(self): class Test(bp.BrainPyObject):