From ad47ce8f977e1167e067009127463e12f5b431ad Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 13:05:58 +0800 Subject: [PATCH] upgrade --- .../_src/math/object_transform/controls.py | 12 +++-- brainpy/_src/math/object_transform/tools.py | 44 ++++++------------- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 286a7891..3edeb08e 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -22,7 +22,6 @@ ) from .tools import ( eval_shape, - eval_shape_with_context, dynvar_deprecation, node_deprecation, abstract @@ -540,8 +539,8 @@ def cond( dyn_vars = get_stack_cache((true_fun, false_fun)) if not jax.config.jax_disable_jit and dyn_vars is None: with VariableStack() as dyn_vars: - rets = eval_shape_with_context(true_fun, *operands) - _ = eval_shape_with_context(false_fun, *operands) + rets = eval_shape(true_fun, *operands, with_stack=True)[1] + _ = eval_shape(false_fun, *operands, with_stack=True) cache_stack((true_fun, false_fun), dyn_vars) if not dyn_vars.is_first_stack(): return rets @@ -676,7 +675,7 @@ def ifelse( dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: with VariableStack() as dyn_vars: - rets = [eval_shape_with_context(fun, *operands) for fun in branches] + rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] trees = [jax.tree_util.tree_structure(ret) for ret in rets] if not _all_equal(trees): msg = 'All returns in branches should have the same tree structure. But we got:\n' @@ -1109,7 +1108,6 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1120,8 +1118,8 @@ def while_loop( stack = get_stack_cache((body_fun, cond_fun)) if not jax.config.jax_disable_jit and stack is None: with VariableStack() as stack: - _ = eval_shape_with_context(cond_fun, *operands) - rets = eval_shape_with_context(body_fun, *operands) + _ = eval_shape(cond_fun, *operands, with_stack=True) + rets = eval_shape(body_fun, *operands, with_stack=True)[1] cache_stack((body_fun, cond_fun), stack) if not stack.is_first_stack(): return rets diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7057d047..632c6d79 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -181,6 +181,7 @@ def eval_shape( *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. @@ -189,6 +190,7 @@ def eval_shape( fun: The callable function. *args: The positional arguments. **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. static_argnums: The static argument indices. static_argnames: The static argument names. @@ -204,40 +206,22 @@ def eval_shape( # evaluate the function fun_in_eval_shape.append(fun) try: - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + if with_stack: + with VariableStack() as stack: + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() del f2 - return returns - - -def eval_shape_with_context( - fun: Callable, - *args, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = (), - return_context: bool = False, - **kwargs -): - """Compute the shape/dtype of ``fun`` without any FLOPs. - - Args: - fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - static_argnums: The static argument indices. - static_argnames: The static argument names. - return_context: Whether to return the variable stack. - - Returns: - The variable stack and the functional returns. - """ - with VariableStack() as stack: - returns = eval_shape(fun, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) - if return_context: + if with_stack: return stack, returns else: return returns