Skip to content

Commit

Permalink
update JIT transform
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 3, 2024
1 parent b0c46c4 commit c12840f
Showing 1 changed file with 27 additions and 25 deletions.
52 changes: 27 additions & 25 deletions brainpy/_src/math/object_transform/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ def _seq_of_str(static_argnames):
return static_argnames


def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs):
# call the transformed function
rng_keys = stack.call_on_subset(_is_rng, _rng_split_key)
changes, out = transform(stack.dict_data(), *args, **kwargs)
for key, v in changes.items():
stack[key]._value = v
for key, v in rng_keys.items():
stack[key]._value = v
return out


class JITTransform(ObjectTransform):
"""Object-oriented JIT transformation in BrainPy."""

Expand Down Expand Up @@ -134,13 +145,13 @@ def __init__(
# OO transformation parameters
self._transform = None
self._dyn_vars = None

def _transform_function(self, variable_data: Dict, *args, **kwargs):
for key, v in self._dyn_vars.items():
v._value = variable_data[key]
out = self.fun(*args, **kwargs)
changes = self._dyn_vars.dict_data_of_subset(_is_not_rng)
return changes, out
#
# def _transform_function(self, variable_data: Dict, *args, **kwargs):
# for key, v in self._dyn_vars.items():
# v._value = variable_data[key]
# out = self.fun(*args, **kwargs)
# changes = self._dyn_vars.dict_data_of_subset(_is_not_rng)
# return changes, out

def _get_transform(self, *args, **kwargs):
with VariableStack() as self._dyn_vars:
Expand Down Expand Up @@ -176,7 +187,7 @@ def _get_transform(self, *args, **kwargs):

# jit
self._transform = jax.jit(
self._transform_function,
_make_transform(self.fun, self._dyn_vars),
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums),
static_argnames=self._static_argnames,
donate_argnums=self._donate_argnums,
Expand All @@ -199,13 +210,7 @@ def __call__(self, *args, **kwargs):
return rets

# call the transformed function
rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key)
changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs)
for key, v in changes.items():
self._dyn_vars[key]._value = v
for key, v in rng_keys.items():
self._dyn_vars[key]._value = v
return out
return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs)

def __repr__(self):
name = self.__class__.__name__
Expand Down Expand Up @@ -466,7 +471,7 @@ def call_fun(self, *args, **kwargs):
if cache is None:
fun2 = partial(fun, self)
with VariableStack() as stack:
_ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
out = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames)
_transform = jax.jit(
_make_transform(fun2, stack),
static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums),
Expand All @@ -478,25 +483,22 @@ def call_fun(self, *args, **kwargs):
**jit_kwargs
)
cache_stack(hash_v, (stack, _transform)) # cache "variable stack" and "transform function"

if not stack.is_first_stack():
return out
else:
stack, _transform = cache
del cache
out, changes = _transform(stack.dict_data(), *args, **kwargs)
for key, v in stack.items():
v._value = changes[key]
return out
return _jit_call_take_care_of_rngs(_transform, stack, *args, **kwargs)

return call_fun


def _make_transform(fun, stack):
@wraps(fun)
def _transform_function(variable_data: dict, *args, **kwargs):
def _transform_function(variable_data: Dict, *args, **kwargs):
for key, v in stack.items():
v._value = variable_data[key]
out = fun(*args, **kwargs)
changes = stack.dict_data()
return out, changes
changes = stack.dict_data_of_subset(_is_not_rng)
return changes, out

return _transform_function

0 comments on commit c12840f

Please sign in to comment.