Skip to content

Commit

Permalink
fix autograd bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 18, 2024
1 parent 79efb61 commit 0a1fa85
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions brainpy/_src/math/object_transform/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __init__(
self.target = target

# transform
self._eval_dyn_vars = False
self._grad_transform = transform
self._dyn_vars = VariableStack()
self._transform = None
Expand Down Expand Up @@ -198,32 +197,30 @@ def __call__(self, *args, **kwargs):
)
return self._return(rets)

elif not self._eval_dyn_vars: # evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with VariableStack() as stack:
rets = eval_shape(self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs)
cache_stack(self.target, stack)
# evaluate dynamical variables
stack = get_stack_cache(self.target)
if stack is None:
with VariableStack() as stack:
rets = eval_shape(self._transform,
[v.value for v in self._grad_vars], # variables for gradients
{}, # dynamical variables
*args,
**kwargs)
cache_stack(self.target, stack)

self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])
self._eval_dyn_vars = True
self._dyn_vars = stack
self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars])

# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

rets = self._transform(
[v.value for v in self._grad_vars], # variables for gradients
self._dyn_vars.dict_data(), # dynamical variables
*args,
**kwargs
)

# if not the outermost transformation
if not stack.is_first_stack():
return self._return(rets)

return self._return(rets)


Expand Down

0 comments on commit 0a1fa85

Please sign in to comment.