From 165b9eef22190ac169e36d0304408097cdd0dff0 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 14:23:15 +0800 Subject: [PATCH] Find back updates (#646) * roll back previous updates * update JIT transform * Update * Update delayvars.py * remove using internal API of jax for registering pytree objects --------- Co-authored-by: He Sichao <1310722434@qq.com> --- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/tests/test_activation.py | 2 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_normalization.py | 3 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- brainpy/_src/math/delayvars.py | 5 +- .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 6 +- .../_src/math/object_transform/controls.py | 136 +++--- brainpy/_src/math/object_transform/jit.py | 129 +++-- brainpy/_src/math/object_transform/naming.py | 10 +- .../_src/math/object_transform/parallels.py | 460 ------------------ brainpy/_src/math/object_transform/tools.py | 75 ++- .../_src/math/object_transform/variables.py | 45 +- brainpy/_src/tools/functions.py | 192 ++++++++ brainpy/_src/tools/tests/test_functions.py | 24 + brainpy/math/compat_pytorch.py | 2 +- brainpy/math/oo_transform.py | 4 +- brainpy/tools.py | 4 + docs/advanced_tutorials.rst | 51 +- docs/apis/brainpy.math.oo_transform.rst | 1 + docs/toolboxes.rst | 38 +- docs/tutorials.rst | 77 ++- 24 files changed, 610 insertions(+), 729 deletions(-) delete mode 100644 brainpy/_src/math/object_transform/parallels.py create mode 100644 brainpy/_src/tools/functions.py create mode 100644 brainpy/_src/tools/tests/test_functions.py diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index deead1f3b..e4b6e25d2 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = x.unsqueeze(0) + x = bm.unsqueeze(x, 0) w = self.w.value if self.mask is not None: try: @@ -190,6 +190,9 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. + The input should a 2d array with the shape of ``[H, C]``, or + a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. + Parameters ---------- in_channels: int @@ -282,6 +285,9 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. + The input should a 3d array with the shape of ``[H, W, C]``, or + a 4d array with the shape of ``[B, H, W, C]``. + Parameters ---------- in_channels: int @@ -375,6 +381,9 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. + The input should a 3d array with the shape of ``[H, W, D, C]``, or + a 4d array with the shape of ``[B, H, W, D, C]``. 
+ Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index ba2a49efd..7a0fa57af 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,5 +1,5 @@ -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 3c9fdfa87..05f523622 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- -from unittest import TestCase -from absl.testing import absltest import jax.numpy as jnp -import brainpy.math as bm +from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): - bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -24,6 +22,7 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -31,7 +30,6 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): - bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -39,6 +37,7 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -54,6 +53,7 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,6 +67,7 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 269fec441..9ad15938d 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- -from unittest import TestCase - -import jax.numpy as jnp -import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index fdc5b34e3..e76b3616b 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,7 +1,8 @@ -import brainpy.math as bm from absl.testing import parameterized from absl.testing import absltest + import brainpy as bp +import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 34f8f5cd5..5748edd8b 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax 
import jax.numpy as jnp import numpy as np -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index eb8e27c8f..676e4286b 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import vstack, broadcast_to +from .compat_numpy import broadcast_to, expand_dims, concatenate from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,6 +392,7 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: + self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) else: self.data[:] = value diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f5e091675..ad8a5ccf6 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,10 +28,8 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from .variables import (Variable, VariableStack) +from .tools import eval_shape __all__ = [ 'grad', # gradient of scalar function @@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with new_transform(self): - with VariableStack() as stack: - if current_transform_number() > 1: - rets = self._transform( - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) - else: - rets = jax.eval_shape( - self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if current_transform_number(): - return self._return(rets) - else: - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + # if not the outermost transformation + if not stack.is_first_stack(): + return self._return(rets) rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index b21ed2af3..936f62386 100644 --- a/brainpy/_src/math/object_transform/base.py +++ 
b/brainpy/_src/math/object_transform/base.py @@ -12,7 +12,6 @@ import jax import numpy as np -from jax._src.tree_util import _registry from jax.tree_util import register_pytree_node_class from brainpy._src.math.modes import Mode @@ -27,6 +26,8 @@ variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) +registered = set() + __all__ = [ 'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform', @@ -91,8 +92,9 @@ def __init__(self, name=None): super().__init__() if defaults.bp_object_as_pytree: - if self.__class__ not in _registry: + if self.__class__ not in registered: register_pytree_node_class(self.__class__) + registered.add(self.__class__) # check whether the object has a unique name. self._name = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 032a0fab6..3edeb08e8 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,17 +21,12 @@ cache_stack ) from .tools import ( - evaluate_dyn_vars, + eval_shape, dynvar_deprecation, node_deprecation, abstract ) -from .variables import ( - Variable, - VariableStack, - new_transform, - current_transform_number, -) +from .variables import (Variable, VariableStack) __all__ = [ 'make_loop', @@ -542,15 +537,13 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('cond'): - dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((true_fun, false_fun), dyn_vars) - if current_transform_number() > 0: - return rets + if not jax.config.jax_disable_jit and dyn_vars is None: + with VariableStack() as dyn_vars: + rets = eval_shape(true_fun, *operands, with_stack=True)[1] + _ = eval_shape(false_fun, *operands, with_stack=True) + cache_stack((true_fun, false_fun), dyn_vars) + if not dyn_vars.is_first_stack(): + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -681,20 +674,16 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with new_transform('ifelse'): - with VariableStack() as dyn_vars: - if current_transform_number() > 1: - rets = [branch(*operands) for branch in branches] - else: - rets = [jax.eval_shape(branch, *operands) for branch in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with VariableStack() as dyn_vars: + rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. 
But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -880,28 +869,23 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) + stack = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if dyn_vars is None: + if stack is None: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, + remat, reverse, unroll, unroll_kwargs) # TODO: better cache mechanism? - with new_transform('for_loop'): - with VariableStack() as dyn_vars: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, - progress_bar, remat, reverse, unroll, - unroll_kwargs) - if current_transform_number() > 1: - rets = transform(operands) - else: - rets = jax.eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache - if current_transform_number(): + with VariableStack() as stack: + rets = eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), stack) # cache + if not stack.is_first_stack(): return rets[1] del rets else: - dyn_vars = VariableStack() + stack = VariableStack() # TODO: cache mechanism? - transform = _get_for_loop_transform(body_fun, dyn_vars, bar, + transform = _get_for_loop_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -909,11 +893,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, dyn_vars + del dyn_vals, stack return out_vals @@ -1011,26 +995,21 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - dyn_vars = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('scan'): - with VariableStack() as dyn_vars: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - if current_transform_number() > 1: - rets = transform(init, operands) - else: - rets = jax.eval_shape(transform, init, operands) - cache_stack(body_fun, dyn_vars) # cache - if current_transform_number(): - return rets[0][1], rets[1] - del rets - - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) + stack = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit and stack is None: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + with VariableStack() as stack: + rets = eval_shape(transform, init, operands) + cache_stack(body_fun, stack) # cache + if not stack.is_first_stack(): + return rets[0][1], rets[1] + del rets + + stack = VariableStack() if stack is None else stack + transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1129,7 +1108,6 @@ def while_loop( No longer need to provide 
``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1137,18 +1115,16 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - dyn_vars = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('while_loop'): - dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((body_fun, cond_fun), dyn_vars) - if current_transform_number(): - return rets - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) - for k, v in dyn_vars.items(): + stack = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit and stack is None: + with VariableStack() as stack: + _ = eval_shape(cond_fun, *operands, with_stack=True) + rets = eval_shape(body_fun, *operands, with_stack=True)[1] + cache_stack((body_fun, cond_fun), stack) + if not stack.is_first_stack(): + return rets + stack = VariableStack() if stack is None else stack + dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) + for k, v in stack.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 7bb36f4e2..551a0949c 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,23 +11,15 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax -from jax.sharding import Sharding from brainpy import tools, check -from .tools import (dynvar_deprecation, - node_deprecation, - evaluate_dyn_vars_with_cache, - evaluate_dyn_vars, - _partial_fun) from .base import BrainPyObject, ObjectTransform from .naming import get_stack_cache, cache_stack +from .tools import (dynvar_deprecation, + node_deprecation, + eval_shape) +from .variables import (Variable, VariableStack) from ..ndarray import Array -from .variables import (Variable, - VariableStack, - outermost_transform, - transform_stack, - current_transform_number, - new_transform) RandomState = None @@ -96,6 +88,17 @@ def _seq_of_str(static_argnames): return static_argnames +def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs): + # call the transformed function + rng_keys = stack.call_on_subset(_is_rng, _rng_split_key) + changes, out = transform(stack.dict_data(), *args, **kwargs) + for key, v in changes.items(): + stack[key]._value = v + for key, v in rng_keys.items(): + stack[key]._value = v + return out + + class JITTransform(ObjectTransform): """Object-oriented JIT transformation in BrainPy.""" @@ -142,25 +145,21 @@ def __init__( # OO transformation parameters self._transform = None self._dyn_vars = None - - def _transform_function(self, variable_data: Dict, *args, **kwargs): - for key, v in self._dyn_vars.items(): - v._value = variable_data[key] - out = self.fun(*args, **kwargs) - changes = self._dyn_vars.dict_data_of_subset(_is_not_rng) - return changes, out + # + # def _transform_function(self, variable_data: Dict, *args, **kwargs): + # for key, v in self._dyn_vars.items(): + # v._value = variable_data[key] + # out = self.fun(*args, **kwargs) + # changes = 
self._dyn_vars.dict_data_of_subset(_is_not_rng) + # return changes, out def _get_transform(self, *args, **kwargs): - with new_transform(self): - self._dyn_vars, rets = evaluate_dyn_vars( - self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - use_eval_shape=current_transform_number() <= 1, - **kwargs - ) - + with VariableStack() as self._dyn_vars: + rets = eval_shape(self.fun, + *args, + **kwargs, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames) # in_shardings if self._in_shardings is None: in_shardings = None @@ -186,18 +185,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + _make_transform(self.fun, self._dyn_vars), + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -207,17 +206,11 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if current_transform_number(): + if not self._dyn_vars.is_first_stack(): return rets # call the transformed function - rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key) - changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs) - for key, v in changes.items(): - self._dyn_vars[key]._value = v - for key, v in rng_keys.items(): - self._dyn_vars[key]._value = v - return out + return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs) def __repr__(self): name = self.__class__.__name__ @@ -314,7 +307,7 @@ def jit( Examples -------- - You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. + You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. >>> import brainpy as bp >>> class Hello(bp.BrainPyObject): @@ -401,12 +394,12 @@ def cls_jit( **kwargs ) -> Callable: """Just-in-time compile a function and then the jitted function as the bound method for a class. - + Examples -------- - + This transformation can be put on any class function. 
For example, - + >>> import brainpy as bp >>> import brainpy.math as bm >>> @@ -415,7 +408,7 @@ def cls_jit( >>> super(SomeProgram, self).__init__() >>> self.a = bm.zeros(2) >>> self.b = bm.Variable(bm.ones(2)) - >>> + >>> >>> @bm.cls_jit(inline=True) >>> def __call__(self, *args, **kwargs): >>> a = bm.random.uniform(size=2) @@ -424,7 +417,7 @@ def cls_jit( >>> >>> program = SomeProgram() >>> program() - + Parameters ---------- {jit_pars} @@ -477,15 +470,8 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - - with jax.ensure_compile_time_eval(): - if len(static_argnums) or len(static_argnames): - fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) - else: - args_, kwargs_, fun3 = args, kwargs, fun2 - with VariableStack() as stack: - _ = jax.eval_shape(fun3, *args_, **kwargs_) - del args_, kwargs_ + with VariableStack() as stack: + out = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), @@ -497,25 +483,22 @@ def call_fun(self, *args, **kwargs): **jit_kwargs ) cache_stack(hash_v, (stack, _transform)) # cache "variable stack" and "transform function" - + if not stack.is_first_stack(): + return out else: stack, _transform = cache - del cache - out, changes = _transform(stack.dict_data(), *args, **kwargs) - for key, v in stack.items(): - v._value = changes[key] - return out + return _jit_call_take_care_of_rngs(_transform, stack, *args, **kwargs) return call_fun def _make_transform(fun, stack): @wraps(fun) - def _transform_function(variable_data: dict, *args, **kwargs): + def _transform_function(variable_data: Dict, *args, **kwargs): for key, v in stack.items(): v._value = variable_data[key] out = fun(*args, **kwargs) - changes = stack.dict_data() - return out, changes + changes = stack.dict_data_of_subset(_is_not_rng) + return changes, out return _transform_function diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 6326929c4..1181e003b 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -import gc + import warnings from brainpy import errors @@ -11,7 +11,6 @@ _name2id = dict() _typed_names = {} -_fun2stack = dict() def check_name_uniqueness(name, obj): @@ -42,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=False): +def clear_name_cache(ignore_warn=True): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -50,14 +49,17 @@ def clear_name_cache(ignore_warn=False): warnings.warn(f'All named models and their ids are cleared.', UserWarning) +_fun2stack = dict() + + def cache_stack(func, stack): _fun2stack[func] = stack def clear_stack_cache(): + """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] - gc.collect() def get_stack_cache(func): diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py deleted file mode 100644 index 1eddce048..000000000 --- a/brainpy/_src/math/object_transform/parallels.py +++ /dev/null @@ -1,460 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -The parallel compilation tools for JAX backend. - -1. Vectorize compilation is implemented by the 'vmap()' function -2. 
Parallel compilation is implemented by the 'pmap()' function - -""" - - -import functools - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters.partial_eval import DynamicJaxprTracer -from jax.interpreters.partial_eval import JaxprTracer -from jax.interpreters.pxla import ShardedDeviceArray - -try: - from jax.errors import UnexpectedTracerError -except ImportError: - from jax.core import UnexpectedTracerError - -from brainpy import errors -from brainpy._src.math.random import RandomState -from brainpy._src.math.ndarray import Array -from brainpy.tools import change_func_name -from .base import BrainPyObject, ArrayCollector - -__all__ = [ - 'vmap', - 'pmap', -] - - -def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, - batch_idx, axis_name, f_name=None): - @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) - def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - out = func(*args, **kwargs) - nonbatched_changes = nonbatched_vars.dict() - batched_changes = batched_vars.dict() - return nonbatched_changes, batched_changes, out - - def call(*args, **kwargs): - n = args[batch_idx[0]].shape[batch_idx[1]] - nonbatched_data = nonbatched_vars.dict() - batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} - try: - out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) - except UnexpectedTracerError as e: - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - raise errors.JaxTracerError() from e - # for key, v in dyn_changes.items(): - # dyn_vars[key] = reduce_func(v) - # for key, v in rand_changes.items(): - # rand_vars[key] = reduce_func(v) - return out - - return change_func_name(name=f_name, f=call) if f_name else call - - -def vmap(func, dyn_vars=None, batched_vars=None, - in_axes=0, out_axes=0, axis_name=None, - reduce_func=None, auto_infer=False): - """Vectorization compilation for class objects. - - Vectorized compile a function or a module to run in parallel on a single device. - - Examples - -------- - - Parameters - ---------- - func : BrainPyObject, function, callable - The function or the module to compile. - dyn_vars : dict, sequence - batched_vars : dict - in_axes : optional, int, sequence of int - Specify which input array axes to map over. If each positional argument to - ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, - or a tuple of integers and Nones with length equal to the number of - positional arguments to ``obj_or_func``. An integer or ``None`` - indicates which array axis to map over for all arguments (with ``None`` - indicating not to map any axis), and a tuple indicates which axis to map - for each corresponding positional argument. Axis integers must be in the - range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of - dimensions (axes) of the corresponding input array. - - If the positional arguments to ``obj_or_func`` are container types, the - corresponding element of ``in_axes`` can itself be a matching container, - so that distinct array axes can be mapped for different container - elements. ``in_axes`` must be a container tree prefix of the positional - argument tuple passed to ``obj_or_func``. - - At least one positional argument must have ``in_axes`` not None. The sizes - of the mapped input axes for all mapped positional arguments must all be - equal. 
- - Arguments passed as keywords are always mapped over their leading axis - (i.e. axis index 0). - out_axes : optional, int, tuple/list/dict - Indicate where the mapped axis should appear in the output. All outputs - with a mapped axis must have a non-None ``out_axes`` specification. Axis - integers must be in the range ``[-ndim, ndim)`` for each output array, - where ``ndim`` is the number of dimensions (axes) of the array returned - by the :func:`vmap`-ed function, which is one more than the number of - dimensions (axes) of the corresponding array returned by ``obj_or_func``. - axis_name : optional - - Returns - ------- - obj_or_func : Any - Batched/vectorized version of ``obj_or_func`` with arguments that correspond to - those of ``obj_or_func``, but with extra array axes at positions indicated by - ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but - with extra array axes at positions indicated by ``out_axes``. - - """ - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # dyn_vars = (dyn_vars or func.vars().unique()) - # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() - # for key, val in dyn_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # in axes - # if in_axes is None: - # in_axes = {key: (None, 0) for key in func.steps.keys()} - # elif isinstance(in_axes, int): - # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, (tuple, list)): - # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in in_axes: - # in_axes = {key: (None, 0, in_axes) for key in keys} - # else: - # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} - # assert isinstance(in_axes, dict) - # - # # batch size index - # batch_idx = {} - # for key, axes in in_axes.items(): - # for i, axis in enumerate(axes[2:]): - # if axis is not None: - # batch_idx[key] = (i, axis) - # break - # else: - # raise ValueError(f'Found no batch axis: {axes}.') - # - # # out axes - # if out_axes is None: - # out_axes = {key: 0 for key in func.steps.keys()} - # elif isinstance(out_axes, int): - # out_axes = {key: out_axes for key in func.steps.keys()} - # elif isinstance(out_axes, (tuple, list)): - # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} - # elif isinstance(out_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in out_axes: - # out_axes = {key: (out_axes, 0, 0) for key in keys} - # else: - # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} - # assert isinstance(out_axes, dict) - # - # # reduce_func - # if reduce_func is None: - # reduce_func = lambda x: x.mean(axis=0) - # - # # vectorized map functions - # for key in func.steps.keys(): - # func.steps[key] = _make_vmap(func=func.steps[key], - # dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # in_axes=in_axes[key], - # out_axes=out_axes[key], - # axis_name=axis_name, - # batch_idx=batch_idx[key], - # reduce_func=reduce_func, - # f_name=key) - # - # return func - - if callable(func): - if auto_infer: - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = 
func.__self__.vars().unique() - - if dyn_vars is None: - return jax.vmap(func, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name) - - else: - if isinstance(dyn_vars, Array): - dyn_vars = [dyn_vars] - if isinstance(dyn_vars, (tuple, list)): - dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} - assert isinstance(dyn_vars, dict) - - # dynamical variables - _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - _rand_vars[key] = val - else: - _dyn_vars[key] = val - - # in axes - if in_axes is None: - in_axes = (None, 0) - elif isinstance(in_axes, (int, dict)): - in_axes = (None, 0, in_axes) - elif isinstance(in_axes, (tuple, list)): - in_axes = (None, 0) + tuple(in_axes) - assert isinstance(in_axes, (tuple, list)) - - # batch size index - batch_idx = {} - for key, axes in batch_idx.items(): - for i, axis in enumerate(axes[2:]): - if axis is not None: - batch_idx[key] = (i, axis) - break - else: - raise ValueError(f'Found no batch axis: {axes}.') - - # out axes - if out_axes is None: - out_axes = 0 - elif isinstance(out_axes, (int, dict)): - out_axes = (out_axes, 0, 0) - elif isinstance(out_axes, (tuple, list)): - out_axes = tuple(out_axes) + (0, 0) - assert isinstance(out_axes, (list, tuple)) - - # reduce_func - if reduce_func is None: - reduce_func = lambda x: x.mean(axis=0) - - # jit function - return _make_vmap(func=func, - nonbatched_vars=_dyn_vars, - batched_vars=_rand_vars, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name, - batch_idx=batch_idx) - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' - f'function, but we got {type(func)}.') - - -def _device_reshape(x): - """Reshape an input array in order to broadcast to multiple devices.""" - num_device = jax.local_device_count() - - if not hasattr(x, 'ndim'): - raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' - f'parallel, first convert it to a Array, for example np.float(0.5)') - if x.ndim == 0: - return np.broadcast_to(x, [num_device]) - if x.shape[0] % num_device != 0: - raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' - f'{num_device} devices, but does not go equally.') - return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) - - -def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, - out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): - @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, - static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, - backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - def pmapped_func(dyn_data, rand_data, *args, **kwargs): - dyn_vars.assign(dyn_data) - rand_vars.assign(rand_data) - out = func(*args, **kwargs) - dyn_changes = dyn_vars.dict() - rand_changes = rand_vars.dict() - return out, dyn_changes, rand_changes - - def call(*args): - un_replicated = [k for k, v in dyn_vars.items() - if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] - if len(un_replicated): - raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' 
- f'did you forget to call xx.replicate() on them?') - _args = [] - for i, x in enumerate(args): - if i + 2 in static_broadcasted_argnums: - _args.append(x) - else: - _args.append(jax.tree_map(_device_reshape, [x])[0]) - dyn_data = dyn_vars.dict() - rand_data = rand_vars.dict() - output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) - dyn_vars.assign(dyn_changes) - rand_vars.assign(rand_changes) - return jax.tree_map(reduce_func, output) - - return change_func_name(name=f_name, f=call) if f_name else call - - -def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), - devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, - reduce_func=None): - """Parallel compilation for class objects. - - Parallel compile a function or a module to run on multiple devices in parallel. - - Parameters - ---------- - func - axis_name - in_axes - out_axes - static_broadcasted_argnums - devices - backend - axis_size - donate_argnums - global_arg_shapes - - Returns - ------- - - - Examples - -------- - - - """ - - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # all_vars = (dyn_vars or func.vars().unique()) - # dyn_vars = ArrayCollector() - # rand_vars = ArrayCollector() - # for key, val in all_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # reduce function - # if reduce_func is None: - # reduce_func = jnp.concatenate - # - # # static broadcast-ed arguments - # if static_broadcasted_argnums is None: - # static_broadcasted_argnums = () - # elif isinstance(static_broadcasted_argnums, int): - # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - # elif isinstance(static_broadcasted_argnums, (tuple, list)): - # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - # assert isinstance(static_broadcasted_argnums, (tuple, list)) - # - # # jit functions - # for key in func.steps.keys(): - # step = func.steps[key] - # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # func=step, - # axis_name=axis_name, - # in_axes=in_axes, - # out_axes=out_axes, - # static_broadcasted_argnums=static_broadcasted_argnums, - # devices=devices, - # backend=backend, - # axis_size=axis_size, - # donate_argnums=donate_argnums, - # global_arg_shapes=global_arg_shapes, - # reduce_func=reduce_func, - # f_name=key) - # return func - - if callable(func): - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.pmap(func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - else: - # dynamical variables - dyn_vars = ArrayCollector() - rand_vars = ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - rand_vars[key] = val - else: - dyn_vars[key] = val - - # static broadcast-ed arguments - if static_broadcasted_argnums is None: - static_broadcasted_argnums = () - elif isinstance(static_broadcasted_argnums, int): - 
static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - elif isinstance(static_broadcasted_argnums, (tuple, list)): - static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - assert isinstance(static_broadcasted_argnums, (tuple, list)) - - # reduce function - if reduce_func is None: - reduce_func = jnp.concatenate - - # jit function - func.__call__ = _make_pmap(dyn_vars=dyn_vars, - rand_vars=rand_vars, - func=func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes, - reduce_func=reduce_func) - return func - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' - f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7b519590a..632c6d79e 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,19 +132,65 @@ def evaluate_dyn_vars_with_cache( return stack +def _partial_fun2( + fun: Callable, + args: tuple, + kwargs: dict, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = () +): + num_args = len(args) + + # arguments + static_args = dict() + dyn_args = [] + dyn_arg_ids = dict() + static_argnums = list(static_argnums) + dyn_i = 0 + for i in range(num_args): + if i in static_argnums: + static_argnums.remove(i) + static_args[i] = args[i] + else: + dyn_args.append(args[i]) + dyn_arg_ids[i] = dyn_i + dyn_i += 1 + if len(static_argnums) > 0: + raise ValueError(f"Invalid static_argnums: {static_argnums}") + + # keyword arguments + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames + + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], + **static_kwargs, + **dynkwargs) + + return new_fun, dyn_args, dyn_kwargs + + def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: - **kwargs: + *args: The positional arguments. + **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. static_argnums: The static argument indices. static_argnames: The static argument names. 
@@ -153,21 +199,30 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(fun, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) + f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) else: - f2, args, kwargs = fun, args, kwargs + f2 = fun # evaluate the function fun_in_eval_shape.append(fun) try: - with jax.ensure_compile_time_eval(): + if with_stack: with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = fun(*args, **kwargs) + returns = f2(*args, **kwargs) else: - returns = jax.eval_shape(fun, *args, **kwargs) + returns = jax.eval_shape(f2, *args, **kwargs) + else: + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() - return stack, returns + del f2 + if with_stack: + return stack, returns + else: + return returns + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 5014da0bf..b7babae8d 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -190,6 +189,14 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id + @classmethod + def num_of_stack(self): + return len(var_stack_list) + + @classmethod + def is_first_stack(self): + return len(var_stack_list) == 0 + def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -210,42 +217,6 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] -transform_stack: List[Callable] = [] - - -@contextmanager -def new_transform(transform: Any): - transform_stack.append(transform) - try: - yield - finally: - transform_stack.pop() - - -def outermost_stack(): - if len(var_stack_list): - return var_stack_list[0] - else: - return None - - -def outermost_transform(): - if len(transform_stack): - return transform_stack[0] - else: - return None - - -def current_transform_number(): - return len(transform_stack) - - -def _stack_add_read(var: 'Variable'): - pass - - -def _stack_add_write(var: 'Variable'): - pass @register_pytree_node_class diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py new file mode 100644 index 000000000..cbc710dba --- /dev/null +++ b/brainpy/_src/tools/functions.py @@ -0,0 +1,192 @@ +import inspect +from functools import partial +from operator import attrgetter +from types import MethodType + +__all__ = [ + 'compose', 'pipe' +] + + +def identity(x): + """ Identity function. Return x + + >>> identity(3) + 3 + """ + return x + + +def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): + """ Like @property, but returns ``classval`` when used as a class attribute + + >>> class MyClass(object): + ... '''The class docstring''' + ... @instanceproperty(classval=__doc__) + ... def __doc__(self): + ... return 'An object docstring' + ... @instanceproperty + ... def val(self): + ... return 42 + ... 
+ >>> MyClass.__doc__ + 'The class docstring' + >>> MyClass.val is None + True + >>> obj = MyClass() + >>> obj.__doc__ + 'An object docstring' + >>> obj.val + 42 + """ + if fget is None: + return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, + classval=classval) + return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, + classval=classval) + + +class InstanceProperty(property): + """ Like @property, but returns ``classval`` when used as a class attribute + + Should not be used directly. Use ``instanceproperty`` instead. + """ + + def __init__(self, fget=None, fset=None, fdel=None, doc=None, + classval=None): + self.classval = classval + property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) + + def __get__(self, obj, type=None): + if obj is None: + return self.classval + return property.__get__(self, obj, type) + + def __reduce__(self): + state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) + return InstanceProperty, state + + +class Compose(object): + """ A composition of functions + + See Also: + compose + """ + __slots__ = 'first', 'funcs' + + def __init__(self, funcs): + funcs = tuple(reversed(funcs)) + self.first = funcs[0] + self.funcs = funcs[1:] + + def __call__(self, *args, **kwargs): + ret = self.first(*args, **kwargs) + for f in self.funcs: + ret = f(ret) + return ret + + def __getstate__(self): + return self.first, self.funcs + + def __setstate__(self, state): + self.first, self.funcs = state + + @instanceproperty(classval=__doc__) + def __doc__(self): + def composed_doc(*fs): + """Generate a docstring for the composition of fs. + """ + if not fs: + # Argument name for the docstring. + return '*args, **kwargs' + + return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) + + try: + return ( + 'lambda *args, **kwargs: ' + + composed_doc(*reversed((self.first,) + self.funcs)) + ) + except AttributeError: + # One of our callables does not have a `__name__`, whatever. + return 'A composition of functions' + + @property + def __name__(self): + try: + return '_of_'.join( + (f.__name__ for f in reversed((self.first,) + self.funcs)) + ) + except AttributeError: + return type(self).__name__ + + def __repr__(self): + return '{.__class__.__name__}{!r}'.format( + self, tuple(reversed((self.first,) + self.funcs))) + + def __eq__(self, other): + if isinstance(other, Compose): + return other.first == self.first and other.funcs == self.funcs + return NotImplemented + + def __ne__(self, other): + equality = self.__eq__(other) + return NotImplemented if equality is NotImplemented else not equality + + def __hash__(self): + return hash(self.first) ^ hash(self.funcs) + + # Mimic the descriptor behavior of python functions. + # i.e. let Compose be called as a method when bound to a class. + # adapted from + # docs.python.org/3/howto/descriptor.html#functions-and-methods + def __get__(self, obj, objtype=None): + return self if obj is None else MethodType(self, obj) + + # introspection with Signature is only possible from py3.3+ + @instanceproperty + def __signature__(self): + base = inspect.signature(self.first) + last = inspect.signature(self.funcs[-1]) + return base.replace(return_annotation=last.return_annotation) + + __wrapped__ = instanceproperty(attrgetter('first')) + + +def compose(*funcs): + """ Compose functions to operate in series. + + Returns a function that applies other functions in sequence. + + Functions are applied from right to left so that + ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. 
+ + If no arguments are provided, the identity function (f(x) = x) is returned. + + >>> inc = lambda i: i + 1 + >>> compose(str, inc)(3) + '4' + """ + if not funcs: + return identity + if len(funcs) == 1: + return funcs[0] + else: + return Compose(funcs) + + +def pipe(*funcs): + """ Pipe a value through a sequence of functions + + I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` + + We think of the value as progressing through a pipe of several + transformations, much like pipes in UNIX + + + >>> double = lambda i: 2 * i + >>> pipe(double, str)(3) + '6' + """ + return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py new file mode 100644 index 000000000..c285e561a --- /dev/null +++ b/brainpy/_src/tools/tests/test_functions.py @@ -0,0 +1,24 @@ + +import unittest + +import brainpy as bp +import brainpy.math as bm + + +class TestFunction(unittest.TestCase): + def test_compose(self): + f = lambda a: a + 1 + g = lambda a: a * 10 + fun1 = bp.tools.compose(f, g) + fun2 = bp.tools.pipe(g, f) + + arr = bm.random.randn(10) + r1 = fun1(arr) + r2 = fun2(arr) + groundtruth = f(g(arr)) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, groundtruth)) + bm.clear_buffer_memory() + + + diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index e4570f6fd..3b0c3f517 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - add as add, + # add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 548a987d0..a488e0742 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -58,4 +58,6 @@ from brainpy._src.math.object_transform.tools import ( eval_shape as eval_shape, ) - +from brainpy._src.math.object_transform.variables import ( + VariableStack as VariableStack, +) diff --git a/brainpy/tools.py b/brainpy/tools.py index 0f3a4c0ef..35e98f6d6 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -43,6 +43,10 @@ from brainpy._src.tools.install import ( jaxlib_install_info, ) +from brainpy._src.tools.functions import ( + compose as compose, + pipe as pipe, +) diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 5c8cba0fd..0b78315ab 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,13 +3,52 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. +Advanced Math +------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 1 + + tutorial_advanced/compilation.ipynb + tutorial_advanced/differentiation.ipynb + + +Interoperation +-------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/integrate_flax_into_brainpy.ipynb + tutorial_advanced/integrate_bp_lif_into_flax.ipynb + tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb + + +Brain Dynamics Dedicated Operators +---------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/operator_custom_with_numba.ipynb + tutorial_advanced/operator_custom_with_taichi.ipynb + + +Developer Guides +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/contributing.md + + +Others +------ + +.. 
toctree:: + :maxdepth: 1 + + tutorial_advanced/advanced_lowdim_analysis.ipynb - tutorial_advanced/1_advanced_math.rst - tutorial_advanced/2_interoperation.rst - tutorial_advanced/3_dedicated_operators.rst - tutorial_advanced/4_developer_guides.rst - tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 754e0d81d..9ed9cf46a 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,4 +77,5 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape + VariableStack diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index 11bf53115..cc3a38575 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,7 +1,16 @@ BDP Toolboxes ================== + + + This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. + + +Differential Equations +----------------------- + + .. toctree:: :maxdepth: 1 @@ -10,11 +19,34 @@ This section contains detailed toolboxes BrainPy uses for brain dynamics modelin tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations + + +Toolbox for Modeling +------------------- + +.. toctree:: + :maxdepth: 1 + tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights + tutorial_toolbox/inputs + + +Toolbox for Training +-------------------- + +.. toctree:: + :maxdepth: 1 + tutorial_toolbox/optimizers - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient - tutorial_toolbox/inputs + +State Resetting, Saving and Loading +----------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 7c9a1c876..57d18332b 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,11 +3,76 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. + +Math Foundation +--------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_math/variables + tutorial_math/control_flows + tutorial_math/Numpy_like_Operations.ipynb + tutorial_math/Dedicated_Operators.ipynb + tutorial_math/einops_in_brainpy.ipynb + + +Model Building with Existing Modules +------------------------------------ + +.. toctree:: + :maxdepth: 1 + + tutorial_building/overview_of_dynamic_model + tutorial_building/build_conductance_neurons_v2.ipynb + tutorial_building/phenon_synapse_models.ipynb + tutorial_building/kinetic_synapse_models.ipynb + tutorial_building/build_network_models + + +Model Building by Customizing New Modules +----------------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_building/customize_neuron_models + tutorial_building/customize_synapse_models + tutorial_building/how_to_customze_a_synapse.ipynb + + +Model Simulation +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_simulation/simulation_dsrunner.ipynb + tutorial_simulation/parallel_for_parameter_exploration.ipynb + tutorial_simulation/monitor_per_multiple_steps.ipynb + + +Model Training +-------------- + +This tutorial shows how to train a dynamical system from data or task. + +.. 
toctree:: + :maxdepth: 1 + + tutorial_training/build_training_models.ipynb + tutorial_training/offline_training.ipynb + tutorial_training/online_training.ipynb + tutorial_training/bp_training.ipynb + tutorial_training/esn_introduction.ipynb + + +Model Analysis +-------------- + .. toctree:: - :maxdepth: 2 + :maxdepth: 1 - tutorial_math/index - tutorial_building/index - tutorial_simulation/index - tutorial_training/index - tutorial_analysis/index + tutorial_analysis/lowdim_analysis + tutorial_analysis/highdim_analysis + tutorial_analysis/decision_making_model
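
A short usage sketch of the `compose` and `pipe` helpers this patch adds under `brainpy.tools` (mirroring the new test in `brainpy/_src/tools/tests/test_functions.py`); the concrete lambdas are illustrative only:

```python
import brainpy as bp
import brainpy.math as bm

inc = lambda a: a + 1
scale = lambda a: a * 10

# compose applies right-to-left: compose(f, g)(x) == f(g(x))
fun1 = bp.tools.compose(inc, scale)
# pipe applies left-to-right: pipe(g, f)(x) == f(g(x))
fun2 = bp.tools.pipe(scale, inc)

x = bm.random.randn(10)
assert bm.allclose(fun1(x), inc(scale(x)))
assert bm.allclose(fun1(x), fun2(x))
```

With no arguments both helpers return the identity function, with a single function they return it unchanged, and with several they return a `Compose` object, so the result can itself be passed to further `compose`/`pipe` calls.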
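A minimal sketch of the reworked `eval_shape` helper together with the newly exported `VariableStack`, assuming both remain reachable from `brainpy.math` as re-exported in `brainpy/math/oo_transform.py` above; the variable `v` and function `f` are hypothetical:

```python
import brainpy.math as bm

v = bm.Variable(bm.zeros(3))

def f(x):
    # Reading `v` inside the traced function lets an enclosing
    # VariableStack record it as a used Variable.
    return v * x + 1.

# Shape/dtype evaluation only; no real FLOPs are executed.
out = bm.eval_shape(f, bm.ones(3))

# With ``with_stack=True`` the Variables touched by ``f`` are collected as well.
stack, out = bm.eval_shape(f, bm.ones(3), with_stack=True)
for name, var in stack.items():
    print(name, var.shape)   # e.g. one entry for `v`
```

The returned `out` holds abstract shape/dtype structures rather than concrete values, matching `jax.eval_shape`; the collected stack is what the rewritten `jit`, `grad`, and control-flow transforms above cache and then consult through `VariableStack.is_first_stack()`.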