Skip to content

Commit

Permalink
upgrade
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Feb 22, 2024
1 parent 9d9cd01 commit ad47ce8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 37 deletions.
12 changes: 5 additions & 7 deletions brainpy/_src/math/object_transform/controls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from .tools import (
eval_shape,
eval_shape_with_context,
dynvar_deprecation,
node_deprecation,
abstract
Expand Down Expand Up @@ -540,8 +539,8 @@ def cond(
dyn_vars = get_stack_cache((true_fun, false_fun))
if not jax.config.jax_disable_jit and dyn_vars is None:
with VariableStack() as dyn_vars:
rets = eval_shape_with_context(true_fun, *operands)
_ = eval_shape_with_context(false_fun, *operands)
rets = eval_shape(true_fun, *operands, with_stack=True)[1]
_ = eval_shape(false_fun, *operands, with_stack=True)
cache_stack((true_fun, false_fun), dyn_vars)
if not dyn_vars.is_first_stack():
return rets
Expand Down Expand Up @@ -676,7 +675,7 @@ def ifelse(
dyn_vars = get_stack_cache(tuple(branches))
if dyn_vars is None:
with VariableStack() as dyn_vars:
rets = [eval_shape_with_context(fun, *operands) for fun in branches]
rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches]
trees = [jax.tree_util.tree_structure(ret) for ret in rets]
if not _all_equal(trees):
msg = 'All returns in branches should have the same tree structure. But we got:\n'
Expand Down Expand Up @@ -1109,7 +1108,6 @@ def while_loop(
No longer need to provide ``child_objs``. This function is capable of automatically
collecting the children objects used in the target ``func``.
"""
dynvar_deprecation(dyn_vars)
node_deprecation(child_objs)
Expand All @@ -1120,8 +1118,8 @@ def while_loop(
stack = get_stack_cache((body_fun, cond_fun))
if not jax.config.jax_disable_jit and stack is None:
with VariableStack() as stack:
_ = eval_shape_with_context(cond_fun, *operands)
rets = eval_shape_with_context(body_fun, *operands)
_ = eval_shape(cond_fun, *operands, with_stack=True)
rets = eval_shape(body_fun, *operands, with_stack=True)[1]
cache_stack((body_fun, cond_fun), stack)
if not stack.is_first_stack():
return rets
Expand Down
44 changes: 14 additions & 30 deletions brainpy/_src/math/object_transform/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def eval_shape(
*args,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
with_stack: bool = False,
**kwargs
):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
Expand All @@ -189,6 +190,7 @@ def eval_shape(
fun: The callable function.
*args: The positional arguments.
**kwargs: The keyword arguments.
with_stack: Whether evaluate the function within a local variable stack.
static_argnums: The static argument indices.
static_argnames: The static argument names.
Expand All @@ -204,40 +206,22 @@ def eval_shape(
# evaluate the function
fun_in_eval_shape.append(fun)
try:
if len(fun_in_eval_shape) > 1:
returns = f2(*args, **kwargs)
if with_stack:
with VariableStack() as stack:
if len(fun_in_eval_shape) > 1:
returns = f2(*args, **kwargs)
else:
returns = jax.eval_shape(f2, *args, **kwargs)
else:
returns = jax.eval_shape(f2, *args, **kwargs)
stack = None
if len(fun_in_eval_shape) > 1:
returns = f2(*args, **kwargs)
else:
returns = jax.eval_shape(f2, *args, **kwargs)
finally:
fun_in_eval_shape.pop()
del f2
return returns


def eval_shape_with_context(
fun: Callable,
*args,
static_argnums: Sequence[int] = (),
static_argnames: Sequence[str] = (),
return_context: bool = False,
**kwargs
):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
Args:
fun: The callable function.
*args: The positional arguments.
**kwargs: The keyword arguments.
static_argnums: The static argument indices.
static_argnames: The static argument names.
return_context: Whether to return the variable stack.
Returns:
The variable stack and the functional returns.
"""
with VariableStack() as stack:
returns = eval_shape(fun, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
if return_context:
if with_stack:
return stack, returns
else:
return returns
Expand Down

0 comments on commit ad47ce8

Please sign in to comment.