diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py
index b868b807..baa1f760 100644
--- a/brainpy/_src/math/object_transform/autograd.py
+++ b/brainpy/_src/math/object_transform/autograd.py
@@ -205,11 +205,12 @@ def __call__(self, *args, **kwargs):
       stack = get_stack_cache(self.target)
       if stack is None:
         with new_transform(self):
-          stack, rets = eval_shape(self._transform,
-                                   [v.value for v in self._grad_vars],  # variables for gradients
-                                   self._dyn_vars.dict_data(),  # dynamical variables
-                                   *args,
-                                   **kwargs)
+          with VariableStack() as stack:
+            rets = eval_shape(self._transform,
+                              [v.value for v in self._grad_vars],  # variables for gradients
+                              {},  # dynamical variables
+                              *args,
+                              **kwargs)
           cache_stack(self.target, stack)
 
         self._dyn_vars = stack
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 2f43ec42..6ce4210b 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -882,7 +882,8 @@ def for_loop(
     with new_transform('for_loop'):
       transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar,
                                           remat, reverse, unroll, unroll_kwargs)
-      dyn_vars, rets = eval_shape(transform, operands)
+      with VariableStack() as dyn_vars:
+        rets = eval_shape(transform, operands)
       cache_stack((body_fun, unroll_kwargs), dyn_vars)  # cache
     if current_transform_number():
       return rets[1]
@@ -1006,7 +1007,8 @@ def scan(
   if dyn_vars is None:
     with new_transform('scan'):
       transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
-      dyn_vars, rets = eval_shape(transform, init, operands)
+      with VariableStack() as dyn_vars:
+        rets = eval_shape(transform, init, operands)
       cache_stack(body_fun, dyn_vars)  # cache
     if current_transform_number():
       return rets[0][1], rets[1]
diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py
index 4ad2e250..3965c1a7 100644
--- a/brainpy/_src/math/object_transform/jit.py
+++ b/brainpy/_src/math/object_transform/jit.py
@@ -19,6 +19,7 @@
                     node_deprecation,
                     eval_shape)
 from .variables import (Variable,
+                        VariableStack,
                         current_transform_number,
                         new_transform)
 from ..ndarray import Array
@@ -146,11 +147,12 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs):
 
   def _get_transform(self, *args, **kwargs):
     with new_transform(self):
-      self._dyn_vars, rets = eval_shape(self.fun,
-                                        *args,
-                                        static_argnums=self._static_argnums,
-                                        static_argnames=self._static_argnames,
-                                        **kwargs)
+      with VariableStack() as self._dyn_vars:
+        rets = eval_shape(self.fun,
+                          *args,
+                          static_argnums=self._static_argnums,
+                          static_argnames=self._static_argnames,
+                          **kwargs)
     # in_shardings
     if self._in_shardings is None:
       in_shardings = None
@@ -467,7 +469,8 @@ def call_fun(self, *args, **kwargs):
     cache = get_stack_cache(hash_v)  # TODO: better cache mechanism
     if cache is None:
       fun2 = partial(fun, self)
-      stack, _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
+      with VariableStack() as stack:
+        _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
       _transform = jax.jit(
         _make_transform(fun2, stack),
         static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py
index e5b13735..6010fac7 100644
--- a/brainpy/_src/math/object_transform/tools.py
+++ b/brainpy/_src/math/object_transform/tools.py
@@ -169,9 +169,9 @@ def _partial_fun2(
 
   @wraps(fun)
   def new_fun(*dynargs, **dynkwargs):
-    return fun(*[dynargs[dyn_arg_ids[i]] if i in dyn_arg_ids else static_args[i]
-                 for i in range(num_args)],
-               **static_kwargs, **dynkwargs)
+    return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)],
+               **static_kwargs,
+               **dynkwargs)
 
   return new_fun, dyn_args, dyn_kwargs
 
@@ -197,23 +197,21 @@
   """
   # reorganize the function
   if len(static_argnums) or len(static_argnames):
-    f2, args, kwargs = _partial_fun2(fun, args, kwargs,
-                                     static_argnums=static_argnums,
-                                     static_argnames=static_argnames)
+    f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
   else:
-    f2, args, kwargs = fun, args, kwargs
+    f2 = fun
 
   # evaluate the function
   fun_in_eval_shape.append(fun)
   try:
-    with VariableStack() as stack:
-      if len(fun_in_eval_shape) > 1:
-        returns = f2(*args, **kwargs)
-      else:
-        returns = jax.eval_shape(f2, *args, **kwargs)
+    if len(fun_in_eval_shape) > 1:
+      returns = f2(*args, **kwargs)
+    else:
+      returns = jax.eval_shape(f2, *args, **kwargs)
+      pass
   finally:
     fun_in_eval_shape.pop()
-  return stack, returns
+  return returns
 
 
 def eval_shape_of_multi_funcs(
@@ -226,7 +224,7 @@
   """Compute the shape/dtype of ``funs`` without any FLOPs.
 
   Args:
-    fun: The callable function.
+    funs: A set of callable functions.
     *args: The positional arguments.
     **kwargs: The keyword arguments.
     static_argnums: The static argument indices.
@@ -235,9 +233,9 @@
     static_argnames: The static argument names.
 
   Returns:
    The variable stack and the functional returns.
   """
-  stack, returns = VariableStack(), []
-  for fun in funs:
-    st, ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs)
-    stack += st
+  returns = []
+  with VariableStack() as stack:
+    for fun in funs:
+      ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs)
       returns.append(ret)
   return stack, returns
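
Note on the new calling convention (not part of the patch itself): after this change, eval_shape no longer returns a (stack, returns) pair; instead the caller opens a VariableStack() context, the stack collects whatever Variables the traced function touches, and eval_shape returns only the abstract outputs. The sketch below illustrates that pattern. It relies on the private module paths referenced in the diff (these may move between brainpy versions), and `fn` / `v` are hypothetical names used only for illustration.

    # Minimal sketch of the caller-owned-stack pattern (see assumptions above).
    import brainpy.math as bm
    from brainpy._src.math.object_transform.tools import eval_shape
    from brainpy._src.math.object_transform.variables import VariableStack

    v = bm.Variable(bm.zeros(3))   # hypothetical state read by the traced function

    def fn(x):
      # Touching `v` while a VariableStack is active registers it on that stack.
      return v.value * x

    # Old: stack, rets = eval_shape(fn, bm.ones(3))
    # New: the caller owns the stack; eval_shape returns only shape/dtype info.
    with VariableStack() as stack:
      rets = eval_shape(fn, bm.ones(3))

    print(stack)   # the collected Variables, including `v`
    print(rets)    # jax.ShapeDtypeStruct pytree describing fn's outputs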