update transformations
chaoming0625 committed Feb 21, 2024
1 parent d336694 commit d703512
Showing 4 changed files with 35 additions and 31 deletions.
11 changes: 6 additions & 5 deletions brainpy/_src/math/object_transform/autograd.py
@@ -205,11 +205,12 @@ def __call__(self, *args, **kwargs):
     stack = get_stack_cache(self.target)
     if stack is None:
       with new_transform(self):
-        stack, rets = eval_shape(self._transform,
-                                 [v.value for v in self._grad_vars],  # variables for gradients
-                                 self._dyn_vars.dict_data(),  # dynamical variables
-                                 *args,
-                                 **kwargs)
+        with VariableStack() as stack:
+          rets = eval_shape(self._transform,
+                            [v.value for v in self._grad_vars],  # variables for gradients
+                            {},  # dynamical variables
+                            *args,
+                            **kwargs)
         cache_stack(self.target, stack)
 
     self._dyn_vars = stack
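Note: both sides of this hunk do the same job, tracing the function abstractly to discover which Variables it touches, but the stack is now collected by a VariableStack context manager instead of being threaded through eval_shape's return value. (The gradient transform also now seeds the trace with {} rather than self._dyn_vars.dict_data(), letting the fresh stack discover the dynamical variables itself.) A minimal, self-contained sketch of the context-manager pattern; ToyStack and ToyVariable are hypothetical stand-ins, not BrainPy's real classes:

import jax
import jax.numpy as jnp

_active_stack = None  # module-level slot standing in for the transform stack


class ToyStack(dict):
  """Hypothetical stand-in for BrainPy's VariableStack."""

  def __enter__(self):
    global _active_stack
    _active_stack = self
    return self

  def __exit__(self, *exc_info):
    global _active_stack
    _active_stack = None


class ToyVariable:
  """Hypothetical stand-in for BrainPy's Variable."""

  def __init__(self, name, value):
    self.name, self.value = name, value

  def read(self):
    # Register with whichever stack is active, then hand back the data.
    if _active_stack is not None:
      _active_stack[self.name] = self
    return self.value


w = ToyVariable('w', jnp.ones(3))

def f(x):
  return w.read() * x

with ToyStack() as stack:                  # collects variables as a side effect of tracing
  rets = jax.eval_shape(f, jnp.zeros(3))   # returns only shapes/dtypes, no FLOPs

print(sorted(stack))  # ['w']
print(rets)           # ShapeDtypeStruct(shape=(3,), dtype=float32)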
6 changes: 4 additions & 2 deletions brainpy/_src/math/object_transform/controls.py
@@ -882,7 +882,8 @@ def for_loop(
     with new_transform('for_loop'):
       transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar,
                                           remat, reverse, unroll, unroll_kwargs)
-      dyn_vars, rets = eval_shape(transform, operands)
+      with VariableStack() as dyn_vars:
+        rets = eval_shape(transform, operands)
       cache_stack((body_fun, unroll_kwargs), dyn_vars)  # cache
     if current_transform_number():
       return rets[1]
@@ -1006,7 +1007,8 @@ def scan(
   if dyn_vars is None:
     with new_transform('scan'):
       transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll)
-      dyn_vars, rets = eval_shape(transform, init, operands)
+      with VariableStack() as dyn_vars:
+        rets = eval_shape(transform, init, operands)
       cache_stack(body_fun, dyn_vars)  # cache
     if current_transform_number():
       return rets[0][1], rets[1]
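Note: from the caller's side nothing changes. for_loop and scan still shape-evaluate the body once, now under an explicitly opened VariableStack, and cache the discovered stack (keyed by (body_fun, unroll_kwargs) for for_loop, by body_fun for scan) so later calls skip the trace. A hedged usage sketch, assuming the public brainpy.math.for_loop API:

import brainpy.math as bm

counter = bm.Variable(bm.zeros(1))

def body(x):
  counter.value += x        # writes a Variable; for_loop must discover and track it
  return counter.value

xs = bm.arange(4.)
outs = bm.for_loop(body, xs)   # first call: eval_shape under a VariableStack, then cache
outs2 = bm.for_loop(body, xs)  # later calls reuse the cached stack for this body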
15 changes: 9 additions & 6 deletions brainpy/_src/math/object_transform/jit.py
@@ -19,6 +19,7 @@
                     node_deprecation,
                     eval_shape)
 from .variables import (Variable,
+                        VariableStack,
                         current_transform_number,
                         new_transform)
 from ..ndarray import Array
@@ -146,11 +147,12 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs):
 
   def _get_transform(self, *args, **kwargs):
     with new_transform(self):
-      self._dyn_vars, rets = eval_shape(self.fun,
-                                        *args,
-                                        static_argnums=self._static_argnums,
-                                        static_argnames=self._static_argnames,
-                                        **kwargs)
+      with VariableStack() as self._dyn_vars:
+        rets = eval_shape(self.fun,
+                          *args,
+                          static_argnums=self._static_argnums,
+                          static_argnames=self._static_argnames,
+                          **kwargs)
     # in_shardings
     if self._in_shardings is None:
       in_shardings = None
@@ -467,7 +469,8 @@ def call_fun(self, *args, **kwargs):
     cache = get_stack_cache(hash_v)  # TODO: better cache mechanism
     if cache is None:
       fun2 = partial(fun, self)
-      stack, _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
+      with VariableStack() as stack:
+        _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
       _transform = jax.jit(
         _make_transform(fun2, stack),
         static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
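Note: the stack captured here is what _make_transform uses to turn a stateful function into a pure one that jax.jit can compile: variable data goes in as an explicit argument and comes back out as part of the result. A self-contained sketch of that mechanism; Var and make_transform are hypothetical stand-ins for bm.Variable and _make_transform, not the library's code:

import jax
import jax.numpy as jnp


class Var:
  """Hypothetical stand-in for bm.Variable."""

  def __init__(self, value):
    self.value = value


def make_transform(fun, stack):
  # Functionalize the state: push data into the variables, run, pull it out.
  def transformed(variable_data, *args):
    for key, var in stack.items():
      var.value = variable_data[key]
    out = fun(*args)
    return {key: var.value for key, var in stack.items()}, out
  return transformed


w = Var(jnp.ones(3))
stack = {'w': w}

def f(x):
  w.value = w.value + x
  return w.value.sum()

jitted = jax.jit(make_transform(f, stack))
new_state, out = jitted({'w': w.value}, jnp.ones(3))
w.value = new_state['w']  # write the updated state back after the call
print(out)                # 6.0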
34 changes: 16 additions & 18 deletions brainpy/_src/math/object_transform/tools.py
@@ -169,9 +169,9 @@ def _partial_fun2(
 
   @wraps(fun)
   def new_fun(*dynargs, **dynkwargs):
-    return fun(*[dynargs[dyn_arg_ids[i]] if i in dyn_arg_ids else static_args[i]
-                 for i in range(num_args)],
-               **static_kwargs, **dynkwargs)
+    return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)],
+               **static_kwargs,
+               **dynkwargs)
 
   return new_fun, dyn_args, dyn_kwargs

@@ -197,23 +197,21 @@ def eval_shape(
   """
   # reorganize the function
   if len(static_argnums) or len(static_argnames):
-    f2, args, kwargs = _partial_fun2(fun, args, kwargs,
-                                     static_argnums=static_argnums,
-                                     static_argnames=static_argnames)
+    f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
   else:
-    f2, args, kwargs = fun, args, kwargs
+    f2 = fun
 
   # evaluate the function
   fun_in_eval_shape.append(fun)
   try:
-    with VariableStack() as stack:
-      if len(fun_in_eval_shape) > 1:
-        returns = f2(*args, **kwargs)
-      else:
-        returns = jax.eval_shape(f2, *args, **kwargs)
+    if len(fun_in_eval_shape) > 1:
+      returns = f2(*args, **kwargs)
+    else:
+      returns = jax.eval_shape(f2, *args, **kwargs)
+    pass
   finally:
     fun_in_eval_shape.pop()
-  return stack, returns
+  return returns


@@ -226,7 +224,7 @@ def eval_shape_of_multi_funcs(
   """Compute the shape/dtype of ``funs`` without any FLOPs.
 
   Args:
-    fun: The callable function.
+    funs: A set of callable functions.
     *args: The positional arguments.
     **kwargs: The keyword arguments.
     static_argnums: The static argument indices.
@@ -235,9 +233,9 @@ def eval_shape_of_multi_funcs(
   Returns:
     The variable stack and the functional returns.
   """
-  stack, returns = VariableStack(), []
-  for fun in funs:
-    st, ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs)
-    stack += st
-    returns.append(ret)
+  returns = []
+  with VariableStack() as stack:
+    for fun in funs:
+      ret = eval_shape(fun, *args, static_argnums=static_argnums, static_argnames=static_argnames, **kwargs)
+      returns.append(ret)
   return stack, returns
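Note: after this refactor, eval_shape is a thin wrapper over jax.eval_shape (it calls the function directly when re-entered, avoiding nested abstract tracing) and never opens a stack of its own; whoever needs the variables opens the VariableStack. A hedged sketch of the resulting caller-side idiom, mirroring eval_shape_of_multi_funcs; the imports reference the private _src paths shown in this diff and may move:

from brainpy._src.math.object_transform.tools import eval_shape
from brainpy._src.math.object_transform.variables import VariableStack


def shapes_under_one_stack(funs, *args, **kwargs):
  # One stack shared by every trace: variables discovered by different
  # functions accumulate in the same place.
  returns = []
  with VariableStack() as stack:
    for fun in funs:
      returns.append(eval_shape(fun, *args, **kwargs))
  return stack, returns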
