From 8c1fdf2855d35df67ba2c50c5c186d99f22b575f Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 18 Feb 2024 21:39:20 +0800 Subject: [PATCH 1/8] [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` (#624) * [tools] add `brainpy.tools.compose` and `brainpy.tools.pipe` * fix --- brainpy/_src/tools/functions.py | 192 +++++++++++++++++++++ brainpy/_src/tools/tests/test_functions.py | 24 +++ brainpy/tools.py | 5 + 3 files changed, 221 insertions(+) create mode 100644 brainpy/_src/tools/functions.py create mode 100644 brainpy/_src/tools/tests/test_functions.py diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py new file mode 100644 index 000000000..cbc710dba --- /dev/null +++ b/brainpy/_src/tools/functions.py @@ -0,0 +1,192 @@ +import inspect +from functools import partial +from operator import attrgetter +from types import MethodType + +__all__ = [ + 'compose', 'pipe' +] + + +def identity(x): + """ Identity function. Return x + + >>> identity(3) + 3 + """ + return x + + +def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): + """ Like @property, but returns ``classval`` when used as a class attribute + + >>> class MyClass(object): + ... '''The class docstring''' + ... @instanceproperty(classval=__doc__) + ... def __doc__(self): + ... return 'An object docstring' + ... @instanceproperty + ... def val(self): + ... return 42 + ... + >>> MyClass.__doc__ + 'The class docstring' + >>> MyClass.val is None + True + >>> obj = MyClass() + >>> obj.__doc__ + 'An object docstring' + >>> obj.val + 42 + """ + if fget is None: + return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, + classval=classval) + return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, + classval=classval) + + +class InstanceProperty(property): + """ Like @property, but returns ``classval`` when used as a class attribute + + Should not be used directly. Use ``instanceproperty`` instead. 
+ """ + + def __init__(self, fget=None, fset=None, fdel=None, doc=None, + classval=None): + self.classval = classval + property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) + + def __get__(self, obj, type=None): + if obj is None: + return self.classval + return property.__get__(self, obj, type) + + def __reduce__(self): + state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) + return InstanceProperty, state + + +class Compose(object): + """ A composition of functions + + See Also: + compose + """ + __slots__ = 'first', 'funcs' + + def __init__(self, funcs): + funcs = tuple(reversed(funcs)) + self.first = funcs[0] + self.funcs = funcs[1:] + + def __call__(self, *args, **kwargs): + ret = self.first(*args, **kwargs) + for f in self.funcs: + ret = f(ret) + return ret + + def __getstate__(self): + return self.first, self.funcs + + def __setstate__(self, state): + self.first, self.funcs = state + + @instanceproperty(classval=__doc__) + def __doc__(self): + def composed_doc(*fs): + """Generate a docstring for the composition of fs. + """ + if not fs: + # Argument name for the docstring. + return '*args, **kwargs' + + return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) + + try: + return ( + 'lambda *args, **kwargs: ' + + composed_doc(*reversed((self.first,) + self.funcs)) + ) + except AttributeError: + # One of our callables does not have a `__name__`, whatever. 
+ return 'A composition of functions' + + @property + def __name__(self): + try: + return '_of_'.join( + (f.__name__ for f in reversed((self.first,) + self.funcs)) + ) + except AttributeError: + return type(self).__name__ + + def __repr__(self): + return '{.__class__.__name__}{!r}'.format( + self, tuple(reversed((self.first,) + self.funcs))) + + def __eq__(self, other): + if isinstance(other, Compose): + return other.first == self.first and other.funcs == self.funcs + return NotImplemented + + def __ne__(self, other): + equality = self.__eq__(other) + return NotImplemented if equality is NotImplemented else not equality + + def __hash__(self): + return hash(self.first) ^ hash(self.funcs) + + # Mimic the descriptor behavior of python functions. + # i.e. let Compose be called as a method when bound to a class. + # adapted from + # docs.python.org/3/howto/descriptor.html#functions-and-methods + def __get__(self, obj, objtype=None): + return self if obj is None else MethodType(self, obj) + + # introspection with Signature is only possible from py3.3+ + @instanceproperty + def __signature__(self): + base = inspect.signature(self.first) + last = inspect.signature(self.funcs[-1]) + return base.replace(return_annotation=last.return_annotation) + + __wrapped__ = instanceproperty(attrgetter('first')) + + +def compose(*funcs): + """ Compose functions to operate in series. + + Returns a function that applies other functions in sequence. + + Functions are applied from right to left so that + ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. + + If no arguments are provided, the identity function (f(x) = x) is returned. + + >>> inc = lambda i: i + 1 + >>> compose(str, inc)(3) + '4' + """ + if not funcs: + return identity + if len(funcs) == 1: + return funcs[0] + else: + return Compose(funcs) + + +def pipe(*funcs): + """ Pipe a value through a sequence of functions + + I.e. 
``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` + + We think of the value as progressing through a pipe of several + transformations, much like pipes in UNIX + + + >>> double = lambda i: 2 * i + >>> pipe(double, str)(3) + '6' + """ + return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py new file mode 100644 index 000000000..c285e561a --- /dev/null +++ b/brainpy/_src/tools/tests/test_functions.py @@ -0,0 +1,24 @@ + +import unittest + +import brainpy as bp +import brainpy.math as bm + + +class TestFunction(unittest.TestCase): + def test_compose(self): + f = lambda a: a + 1 + g = lambda a: a * 10 + fun1 = bp.tools.compose(f, g) + fun2 = bp.tools.pipe(g, f) + + arr = bm.random.randn(10) + r1 = fun1(arr) + r2 = fun2(arr) + groundtruth = f(g(arr)) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, groundtruth)) + bm.clear_buffer_memory() + + + diff --git a/brainpy/tools.py b/brainpy/tools.py index 0f3a4c0ef..233269dc5 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -45,4 +45,9 @@ ) +from brainpy._src.tools.functions import ( + compose as compose, + pipe as pipe, +) + From 48455e588528a397077ebbea66daa93ef2e617b6 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 10:29:27 +0800 Subject: [PATCH 2/8] doc hierarchy update (#630) --- docs/advanced_tutorials.rst | 51 +++++++++++++++++++++--- docs/toolboxes.rst | 38 ++++++++++++++++-- docs/tutorials.rst | 77 ++++++++++++++++++++++++++++++++++--- 3 files changed, 151 insertions(+), 15 deletions(-) diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 5c8cba0fd..0b78315ab 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,13 +3,52 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. +Advanced Math +------------- .. 
toctree:: - :maxdepth: 2 + :maxdepth: 1 + + tutorial_advanced/compilation.ipynb + tutorial_advanced/differentiation.ipynb + + +Interoperation +-------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/integrate_flax_into_brainpy.ipynb + tutorial_advanced/integrate_bp_lif_into_flax.ipynb + tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb + + +Brain Dynamics Dedicated Operators +---------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/operator_custom_with_numba.ipynb + tutorial_advanced/operator_custom_with_taichi.ipynb + + +Developer Guides +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/contributing.md + + +Others +------ + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/advanced_lowdim_analysis.ipynb - tutorial_advanced/1_advanced_math.rst - tutorial_advanced/2_interoperation.rst - tutorial_advanced/3_dedicated_operators.rst - tutorial_advanced/4_developer_guides.rst - tutorial_advanced/5_others.rst diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index 11bf53115..cc3a38575 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,7 +1,16 @@ BDP Toolboxes ================== + + + This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. + + +Differential Equations +----------------------- + + .. toctree:: :maxdepth: 1 @@ -10,11 +19,34 @@ This section contains detailed toolboxes BrainPy uses for brain dynamics modelin tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations + + +Toolbox for Modeling +-------------------- + +.. toctree:: + :maxdepth: 1 + tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights + tutorial_toolbox/inputs + + +Toolbox for Training +-------------------- + +.. 
toctree:: + :maxdepth: 1 + tutorial_toolbox/optimizers - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient - tutorial_toolbox/inputs + +State Resetting, Saving and Loading +----------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 7c9a1c876..57d18332b 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,11 +3,76 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. + +Math Foundation +--------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_math/variables + tutorial_math/control_flows + tutorial_math/Numpy_like_Operations.ipynb + tutorial_math/Dedicated_Operators.ipynb + tutorial_math/einops_in_brainpy.ipynb + + +Model Building with Existing Modules +------------------------------------ + +.. toctree:: + :maxdepth: 1 + + tutorial_building/overview_of_dynamic_model + tutorial_building/build_conductance_neurons_v2.ipynb + tutorial_building/phenon_synapse_models.ipynb + tutorial_building/kinetic_synapse_models.ipynb + tutorial_building/build_network_models + + +Model Building by Customizing New Modules +----------------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_building/customize_neuron_models + tutorial_building/customize_synapse_models + tutorial_building/how_to_customze_a_synapse.ipynb + + +Model Simulation +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_simulation/simulation_dsrunner.ipynb + tutorial_simulation/parallel_for_parameter_exploration.ipynb + tutorial_simulation/monitor_per_multiple_steps.ipynb + + +Model Training +-------------- + +This tutorial shows how to train a dynamical system from data or task. + +.. 
toctree:: + :maxdepth: 1 + + tutorial_training/build_training_models.ipynb + tutorial_training/offline_training.ipynb + tutorial_training/online_training.ipynb + tutorial_training/bp_training.ipynb + tutorial_training/esn_introduction.ipynb + + +Model Analysis +-------------- + .. toctree:: - :maxdepth: 2 + :maxdepth: 1 - tutorial_math/index - tutorial_building/index - tutorial_simulation/index - tutorial_training/index - tutorial_analysis/index + tutorial_analysis/lowdim_analysis + tutorial_analysis/highdim_analysis + tutorial_analysis/decision_making_model From 4d7481699c214e7b1d416d4b0e99073c1113d4aa Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 13:19:26 +0800 Subject: [PATCH 3/8] Standardizing and generalizing object-oriented transformations (#628) * test improvement * remove pytorch add * variable evaluation using `brainpy.math.eval_shape` * fix bugs * update transformations * remove `new_transform` API * update * update * fix * fix * fix bugs * fix bugs * updates * updates * upgrade * add `brainpy.math.VariableStack` --- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/tests/test_activation.py | 3 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_linear.py | 5 +- brainpy/_src/dnn/tests/test_mode.py | 5 +- brainpy/_src/dnn/tests/test_normalization.py | 5 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 4 +- .../_src/math/object_transform/controls.py | 136 +++--- brainpy/_src/math/object_transform/jit.py | 69 +-- brainpy/_src/math/object_transform/naming.py | 3 +- .../_src/math/object_transform/parallels.py | 460 ------------------ brainpy/_src/math/object_transform/tools.py | 75 ++- .../_src/math/object_transform/variables.py | 45 +- brainpy/math/compat_pytorch.py | 2 +- brainpy/math/oo_transform.py | 4 + docs/apis/brainpy.math.oo_transform.rst | 1 + 
examples/dynamics_simulation/ei_nets.py | 2 +- 20 files changed, 208 insertions(+), 686 deletions(-) delete mode 100644 brainpy/_src/math/object_transform/parallels.py diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index deead1f3b..e4b6e25d2 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = x.unsqueeze(0) + x = bm.unsqueeze(x, 0) w = self.w.value if self.mask is not None: try: @@ -190,6 +190,9 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. + The input should be a 2d array with the shape of ``[H, C]``, or + a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. + Parameters ---------- in_channels: int @@ -282,6 +285,9 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. + The input should be a 3d array with the shape of ``[H, W, C]``, or + a 4d array with the shape of ``[B, H, W, C]``. + Parameters ---------- in_channels: int @@ -375,6 +381,9 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. + The input should be a 4d array with the shape of ``[H, W, D, C]``, or + a 5d array with the shape of ``[B, H, W, D, C]``. 
+ Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index ba2a49efd..17054667d 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,5 +1,6 @@ -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 3c9fdfa87..05f523622 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- -from unittest import TestCase -from absl.testing import absltest import jax.numpy as jnp -import brainpy.math as bm +from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): - bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -24,6 +22,7 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -31,7 +30,6 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): - bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -39,6 +37,7 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -54,6 +53,7 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 
32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,6 +67,7 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 269fec441..9ad15938d 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- -from unittest import TestCase - -import jax.numpy as jnp -import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 7fc89526c..df5293ab9 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,6 +1,7 @@ -import brainpy as bp -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 0d754976f..3cf923d7b 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,7 +1,8 @@ -import brainpy.math as bm -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class Test_Conv(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index fdc5b34e3..de2c9765b 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,7 +1,8 @@ -import brainpy.math as bm -from absl.testing 
import parameterized from absl.testing import absltest +from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 34f8f5cd5..5748edd8b 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f5e091675..ad8a5ccf6 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,10 +28,8 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from .variables import (Variable, VariableStack) +from .tools import eval_shape __all__ = [ 'grad', # gradient of scalar function @@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with new_transform(self): - with VariableStack() as stack: - if current_transform_number() > 1: - rets = self._transform( - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) - else: - rets = jax.eval_shape( - self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + 
*args, + **kwargs) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if current_transform_number(): - return self._return(rets) - else: - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + # if not the outermost transformation + if not stack.is_first_stack(): + return self._return(rets) rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index aaf053ae7..c52845a06 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,7 +6,6 @@ """ import numbers -import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -14,14 +13,13 @@ import jax import numpy as np -from brainpy import errors +from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) -from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 032a0fab6..3edeb08e8 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,17 +21,12 @@ cache_stack ) from .tools import ( - evaluate_dyn_vars, 
+ eval_shape, dynvar_deprecation, node_deprecation, abstract ) -from .variables import ( - Variable, - VariableStack, - new_transform, - current_transform_number, -) +from .variables import (Variable, VariableStack) __all__ = [ 'make_loop', @@ -542,15 +537,13 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('cond'): - dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((true_fun, false_fun), dyn_vars) - if current_transform_number() > 0: - return rets + if not jax.config.jax_disable_jit and dyn_vars is None: + with VariableStack() as dyn_vars: + rets = eval_shape(true_fun, *operands, with_stack=True)[1] + _ = eval_shape(false_fun, *operands, with_stack=True) + cache_stack((true_fun, false_fun), dyn_vars) + if not dyn_vars.is_first_stack(): + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -681,20 +674,16 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with new_transform('ifelse'): - with VariableStack() as dyn_vars: - if current_transform_number() > 1: - rets = [branch(*operands) for branch in branches] - else: - rets = [jax.eval_shape(branch, *operands) for branch in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. 
But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with VariableStack() as dyn_vars: + rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -880,28 +869,23 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) + stack = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if dyn_vars is None: + if stack is None: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, + remat, reverse, unroll, unroll_kwargs) # TODO: better cache mechanism? - with new_transform('for_loop'): - with VariableStack() as dyn_vars: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, - progress_bar, remat, reverse, unroll, - unroll_kwargs) - if current_transform_number() > 1: - rets = transform(operands) - else: - rets = jax.eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache - if current_transform_number(): + with VariableStack() as stack: + rets = eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), stack) # cache + if not stack.is_first_stack(): return rets[1] del rets else: - dyn_vars = VariableStack() + stack = VariableStack() # TODO: cache mechanism? 
- transform = _get_for_loop_transform(body_fun, dyn_vars, bar, + transform = _get_for_loop_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -909,11 +893,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, dyn_vars + del dyn_vals, stack return out_vals @@ -1011,26 +995,21 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - dyn_vars = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('scan'): - with VariableStack() as dyn_vars: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - if current_transform_number() > 1: - rets = transform(init, operands) - else: - rets = jax.eval_shape(transform, init, operands) - cache_stack(body_fun, dyn_vars) # cache - if current_transform_number(): - return rets[0][1], rets[1] - del rets - - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) + stack = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit and stack is None: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + with VariableStack() as stack: + rets = eval_shape(transform, init, operands) + cache_stack(body_fun, stack) # cache + if not stack.is_first_stack(): + return rets[0][1], rets[1] + del rets + + stack = VariableStack() if stack is None else stack + transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + 
for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1129,7 +1108,6 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1137,18 +1115,16 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - dyn_vars = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('while_loop'): - dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((body_fun, cond_fun), dyn_vars) - if current_transform_number(): - return rets - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) - for k, v in dyn_vars.items(): + stack = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit and stack is None: + with VariableStack() as stack: + _ = eval_shape(cond_fun, *operands, with_stack=True) + rets = eval_shape(body_fun, *operands, with_stack=True)[1] + cache_stack((body_fun, cond_fun), stack) + if not stack.is_first_stack(): + return rets + stack = VariableStack() if stack is None else stack + dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) + for k, v in stack.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 7bb36f4e2..73eab2f91 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,23 +11,15 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import 
jax -from jax.sharding import Sharding from brainpy import tools, check -from .tools import (dynvar_deprecation, - node_deprecation, - evaluate_dyn_vars_with_cache, - evaluate_dyn_vars, - _partial_fun) from .base import BrainPyObject, ObjectTransform from .naming import get_stack_cache, cache_stack +from .tools import (dynvar_deprecation, + node_deprecation, + eval_shape) +from .variables import (Variable, VariableStack) from ..ndarray import Array -from .variables import (Variable, - VariableStack, - outermost_transform, - transform_stack, - current_transform_number, - new_transform) RandomState = None @@ -151,16 +143,12 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with new_transform(self): - self._dyn_vars, rets = evaluate_dyn_vars( - self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - use_eval_shape=current_transform_number() <= 1, - **kwargs - ) - + with VariableStack() as self._dyn_vars: + rets = eval_shape(self.fun, + *args, + **kwargs, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames) # in_shardings if self._in_shardings is None: in_shardings = None @@ -186,18 +174,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + 
static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -207,7 +195,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if current_transform_number(): + if not self._dyn_vars.is_first_stack(): return rets # call the transformed function @@ -477,15 +465,8 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - - with jax.ensure_compile_time_eval(): - if len(static_argnums) or len(static_argnames): - fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) - else: - args_, kwargs_, fun3 = args, kwargs, fun2 - with VariableStack() as stack: - _ = jax.eval_shape(fun3, *args_, **kwargs_) - del args_, kwargs_ + with VariableStack() as stack: + _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1c8ca6ef9..1181e003b 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=False): +def clear_name_cache(ignore_warn=True): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,6 +57,7 @@ def cache_stack(func, stack): def clear_stack_cache(): + """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] diff --git 
a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py deleted file mode 100644 index 1eddce048..000000000 --- a/brainpy/_src/math/object_transform/parallels.py +++ /dev/null @@ -1,460 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -The parallel compilation tools for JAX backend. - -1. Vectorize compilation is implemented by the 'vmap()' function -2. Parallel compilation is implemented by the 'pmap()' function - -""" - - -import functools - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters.partial_eval import DynamicJaxprTracer -from jax.interpreters.partial_eval import JaxprTracer -from jax.interpreters.pxla import ShardedDeviceArray - -try: - from jax.errors import UnexpectedTracerError -except ImportError: - from jax.core import UnexpectedTracerError - -from brainpy import errors -from brainpy._src.math.random import RandomState -from brainpy._src.math.ndarray import Array -from brainpy.tools import change_func_name -from .base import BrainPyObject, ArrayCollector - -__all__ = [ - 'vmap', - 'pmap', -] - - -def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, - batch_idx, axis_name, f_name=None): - @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) - def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - out = func(*args, **kwargs) - nonbatched_changes = nonbatched_vars.dict() - batched_changes = batched_vars.dict() - return nonbatched_changes, batched_changes, out - - def call(*args, **kwargs): - n = args[batch_idx[0]].shape[batch_idx[1]] - nonbatched_data = nonbatched_vars.dict() - batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} - try: - out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) - except UnexpectedTracerError as e: - nonbatched_vars.assign(nonbatched_data) - 
batched_vars.assign(batched_data) - raise errors.JaxTracerError() from e - # for key, v in dyn_changes.items(): - # dyn_vars[key] = reduce_func(v) - # for key, v in rand_changes.items(): - # rand_vars[key] = reduce_func(v) - return out - - return change_func_name(name=f_name, f=call) if f_name else call - - -def vmap(func, dyn_vars=None, batched_vars=None, - in_axes=0, out_axes=0, axis_name=None, - reduce_func=None, auto_infer=False): - """Vectorization compilation for class objects. - - Vectorized compile a function or a module to run in parallel on a single device. - - Examples - -------- - - Parameters - ---------- - func : BrainPyObject, function, callable - The function or the module to compile. - dyn_vars : dict, sequence - batched_vars : dict - in_axes : optional, int, sequence of int - Specify which input array axes to map over. If each positional argument to - ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, - or a tuple of integers and Nones with length equal to the number of - positional arguments to ``obj_or_func``. An integer or ``None`` - indicates which array axis to map over for all arguments (with ``None`` - indicating not to map any axis), and a tuple indicates which axis to map - for each corresponding positional argument. Axis integers must be in the - range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of - dimensions (axes) of the corresponding input array. - - If the positional arguments to ``obj_or_func`` are container types, the - corresponding element of ``in_axes`` can itself be a matching container, - so that distinct array axes can be mapped for different container - elements. ``in_axes`` must be a container tree prefix of the positional - argument tuple passed to ``obj_or_func``. - - At least one positional argument must have ``in_axes`` not None. The sizes - of the mapped input axes for all mapped positional arguments must all be - equal. 
- - Arguments passed as keywords are always mapped over their leading axis - (i.e. axis index 0). - out_axes : optional, int, tuple/list/dict - Indicate where the mapped axis should appear in the output. All outputs - with a mapped axis must have a non-None ``out_axes`` specification. Axis - integers must be in the range ``[-ndim, ndim)`` for each output array, - where ``ndim`` is the number of dimensions (axes) of the array returned - by the :func:`vmap`-ed function, which is one more than the number of - dimensions (axes) of the corresponding array returned by ``obj_or_func``. - axis_name : optional - - Returns - ------- - obj_or_func : Any - Batched/vectorized version of ``obj_or_func`` with arguments that correspond to - those of ``obj_or_func``, but with extra array axes at positions indicated by - ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but - with extra array axes at positions indicated by ``out_axes``. - - """ - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # dyn_vars = (dyn_vars or func.vars().unique()) - # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() - # for key, val in dyn_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # in axes - # if in_axes is None: - # in_axes = {key: (None, 0) for key in func.steps.keys()} - # elif isinstance(in_axes, int): - # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, (tuple, list)): - # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in in_axes: - # in_axes = {key: (None, 0, in_axes) for key in keys} - # else: - # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} - # assert isinstance(in_axes, dict) - # - # # batch size index - # batch_idx = {} - # 
for key, axes in in_axes.items(): - # for i, axis in enumerate(axes[2:]): - # if axis is not None: - # batch_idx[key] = (i, axis) - # break - # else: - # raise ValueError(f'Found no batch axis: {axes}.') - # - # # out axes - # if out_axes is None: - # out_axes = {key: 0 for key in func.steps.keys()} - # elif isinstance(out_axes, int): - # out_axes = {key: out_axes for key in func.steps.keys()} - # elif isinstance(out_axes, (tuple, list)): - # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} - # elif isinstance(out_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in out_axes: - # out_axes = {key: (out_axes, 0, 0) for key in keys} - # else: - # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} - # assert isinstance(out_axes, dict) - # - # # reduce_func - # if reduce_func is None: - # reduce_func = lambda x: x.mean(axis=0) - # - # # vectorized map functions - # for key in func.steps.keys(): - # func.steps[key] = _make_vmap(func=func.steps[key], - # dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # in_axes=in_axes[key], - # out_axes=out_axes[key], - # axis_name=axis_name, - # batch_idx=batch_idx[key], - # reduce_func=reduce_func, - # f_name=key) - # - # return func - - if callable(func): - if auto_infer: - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.vmap(func, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name) - - else: - if isinstance(dyn_vars, Array): - dyn_vars = [dyn_vars] - if isinstance(dyn_vars, (tuple, list)): - dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} - assert isinstance(dyn_vars, dict) - - # dynamical variables - _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() - for key, val in 
dyn_vars.items(): - if isinstance(val, RandomState): - _rand_vars[key] = val - else: - _dyn_vars[key] = val - - # in axes - if in_axes is None: - in_axes = (None, 0) - elif isinstance(in_axes, (int, dict)): - in_axes = (None, 0, in_axes) - elif isinstance(in_axes, (tuple, list)): - in_axes = (None, 0) + tuple(in_axes) - assert isinstance(in_axes, (tuple, list)) - - # batch size index - batch_idx = {} - for key, axes in batch_idx.items(): - for i, axis in enumerate(axes[2:]): - if axis is not None: - batch_idx[key] = (i, axis) - break - else: - raise ValueError(f'Found no batch axis: {axes}.') - - # out axes - if out_axes is None: - out_axes = 0 - elif isinstance(out_axes, (int, dict)): - out_axes = (out_axes, 0, 0) - elif isinstance(out_axes, (tuple, list)): - out_axes = tuple(out_axes) + (0, 0) - assert isinstance(out_axes, (list, tuple)) - - # reduce_func - if reduce_func is None: - reduce_func = lambda x: x.mean(axis=0) - - # jit function - return _make_vmap(func=func, - nonbatched_vars=_dyn_vars, - batched_vars=_rand_vars, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name, - batch_idx=batch_idx) - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' - f'function, but we got {type(func)}.') - - -def _device_reshape(x): - """Reshape an input array in order to broadcast to multiple devices.""" - num_device = jax.local_device_count() - - if not hasattr(x, 'ndim'): - raise errors.BrainPyError(f'Expected Array, got {type(x)}. 
If you are trying to pass a scalar to ' - f'parallel, first convert it to a Array, for example np.float(0.5)') - if x.ndim == 0: - return np.broadcast_to(x, [num_device]) - if x.shape[0] % num_device != 0: - raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' - f'{num_device} devices, but does not go equally.') - return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) - - -def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, - out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): - @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, - static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, - backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - def pmapped_func(dyn_data, rand_data, *args, **kwargs): - dyn_vars.assign(dyn_data) - rand_vars.assign(rand_data) - out = func(*args, **kwargs) - dyn_changes = dyn_vars.dict() - rand_changes = rand_vars.dict() - return out, dyn_changes, rand_changes - - def call(*args): - un_replicated = [k for k, v in dyn_vars.items() - if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] - if len(un_replicated): - raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' 
- f'did you forget to call xx.replicate() on them?') - _args = [] - for i, x in enumerate(args): - if i + 2 in static_broadcasted_argnums: - _args.append(x) - else: - _args.append(jax.tree_map(_device_reshape, [x])[0]) - dyn_data = dyn_vars.dict() - rand_data = rand_vars.dict() - output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) - dyn_vars.assign(dyn_changes) - rand_vars.assign(rand_changes) - return jax.tree_map(reduce_func, output) - - return change_func_name(name=f_name, f=call) if f_name else call - - -def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), - devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, - reduce_func=None): - """Parallel compilation for class objects. - - Parallel compile a function or a module to run on multiple devices in parallel. - - Parameters - ---------- - func - axis_name - in_axes - out_axes - static_broadcasted_argnums - devices - backend - axis_size - donate_argnums - global_arg_shapes - - Returns - ------- - - - Examples - -------- - - - """ - - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # all_vars = (dyn_vars or func.vars().unique()) - # dyn_vars = ArrayCollector() - # rand_vars = ArrayCollector() - # for key, val in all_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # reduce function - # if reduce_func is None: - # reduce_func = jnp.concatenate - # - # # static broadcast-ed arguments - # if static_broadcasted_argnums is None: - # static_broadcasted_argnums = () - # elif isinstance(static_broadcasted_argnums, int): - # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - # elif isinstance(static_broadcasted_argnums, (tuple, list)): - # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - # assert 
isinstance(static_broadcasted_argnums, (tuple, list)) - # - # # jit functions - # for key in func.steps.keys(): - # step = func.steps[key] - # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # func=step, - # axis_name=axis_name, - # in_axes=in_axes, - # out_axes=out_axes, - # static_broadcasted_argnums=static_broadcasted_argnums, - # devices=devices, - # backend=backend, - # axis_size=axis_size, - # donate_argnums=donate_argnums, - # global_arg_shapes=global_arg_shapes, - # reduce_func=reduce_func, - # f_name=key) - # return func - - if callable(func): - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.pmap(func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - else: - # dynamical variables - dyn_vars = ArrayCollector() - rand_vars = ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - rand_vars[key] = val - else: - dyn_vars[key] = val - - # static broadcast-ed arguments - if static_broadcasted_argnums is None: - static_broadcasted_argnums = () - elif isinstance(static_broadcasted_argnums, int): - static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - elif isinstance(static_broadcasted_argnums, (tuple, list)): - static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - assert isinstance(static_broadcasted_argnums, (tuple, list)) - - # reduce function - if reduce_func is None: - reduce_func = jnp.concatenate - - # jit function - func.__call__ = _make_pmap(dyn_vars=dyn_vars, - 
rand_vars=rand_vars, - func=func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes, - reduce_func=reduce_func) - return func - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' - f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7b519590a..632c6d79e 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,19 +132,65 @@ def evaluate_dyn_vars_with_cache( return stack +def _partial_fun2( + fun: Callable, + args: tuple, + kwargs: dict, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = () +): + num_args = len(args) + + # arguments + static_args = dict() + dyn_args = [] + dyn_arg_ids = dict() + static_argnums = list(static_argnums) + dyn_i = 0 + for i in range(num_args): + if i in static_argnums: + static_argnums.remove(i) + static_args[i] = args[i] + else: + dyn_args.append(args[i]) + dyn_arg_ids[i] = dyn_i + dyn_i += 1 + if len(static_argnums) > 0: + raise ValueError(f"Invalid static_argnums: {static_argnums}") + + # keyword arguments + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames + + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], + **static_kwargs, + **dynkwargs) + + return new_fun, dyn_args, dyn_kwargs + + def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + with_stack: bool = False, **kwargs ): """Compute the 
shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: - **kwargs: + *args: The positional arguments. + **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. static_argnums: The static argument indices. static_argnames: The static argument names. @@ -153,21 +199,30 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(fun, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) + f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) else: - f2, args, kwargs = fun, args, kwargs + f2 = fun # evaluate the function fun_in_eval_shape.append(fun) try: - with jax.ensure_compile_time_eval(): + if with_stack: with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = fun(*args, **kwargs) + returns = f2(*args, **kwargs) else: - returns = jax.eval_shape(fun, *args, **kwargs) + returns = jax.eval_shape(f2, *args, **kwargs) + else: + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() - return stack, returns + del f2 + if with_stack: + return stack, returns + else: + return returns + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 5014da0bf..b7babae8d 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -190,6 +189,14 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id + @classmethod + def num_of_stack(self): + return len(var_stack_list) + + @classmethod + def is_first_stack(self): + return 
len(var_stack_list) == 0 + def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -210,42 +217,6 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] -transform_stack: List[Callable] = [] - - -@contextmanager -def new_transform(transform: Any): - transform_stack.append(transform) - try: - yield - finally: - transform_stack.pop() - - -def outermost_stack(): - if len(var_stack_list): - return var_stack_list[0] - else: - return None - - -def outermost_transform(): - if len(transform_stack): - return transform_stack[0] - else: - return None - - -def current_transform_number(): - return len(transform_stack) - - -def _stack_add_read(var: 'Variable'): - pass - - -def _stack_add_write(var: 'Variable'): - pass @register_pytree_node_class diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index e4570f6fd..3b0c3f517 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - add as add, + # add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 548a987d0..7654731d8 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,3 +59,7 @@ eval_shape as eval_shape, ) +from brainpy._src.math.object_transform.variables import ( + VariableStack as VariableStack, +) + diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 754e0d81d..9ed9cf46a 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,4 +77,5 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape + VariableStack diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index f98527458..9c7daff55 100644 --- 
a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('I') + spk = self.delay.at('delay') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) From 5eb7cee3e105dc2cf2460a0b567c3403950c752c Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Thu, 22 Feb 2024 17:14:12 +0800 Subject: [PATCH 4/8] fix #626 (#631) --- brainpy/_src/math/delayvars.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index eb8e27c8f..390e04dd7 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import vstack, broadcast_to +from .compat_numpy import broadcast_to, expand_dims, concatenate from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,6 +392,7 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: + self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + self.data.value = concatenate([expand_dims(value, 0), self.data[1:]], axis=0) else: self.data[:] = value From 64d8f54c34b8cdacd32274ca74aca7bfb639e0f6 Mon Sep 17 00:00:00 2001 From: yunhui <38786521+CloudyDory@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:56:23 +0800 Subject: [PATCH 5/8] Fix delayvar not correct in concat mode (#632) --- brainpy/_src/math/delayvars.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 390e04dd7..676e4286b 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -473,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[1:]], axis=0) + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) else: self.data[:] = value From a579e411cd857ba1230aa7f816e7e40a07c49fd2 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Fri, 1 Mar 2024 10:39:33 +0800 Subject: [PATCH 6/8] [dependency] remove hard dependency of `taichi` and `numba` (#635) * try to remove hard dependency with taichi and numba * [math] Update operator selection strategy for csr matvec * [math] Remove old test case of event csr matvec and csr matvec * [dependency] remove all numba and taichi dependency * fix * Update * Update test_taichi_clean_cache.py * Update CI and remove taichi, numba from requirements * Resolve conflicts * Revert "Merge branch 'master' into dependency-optimize" This reverts commit c54c82214f6f463d9cf0cb9b53d469fe9521edd1, reversing changes made to 76202d5baa427d5698f566a1260a66d8e4f51c8e. 
* upgrade dependency * fix * update * update * update doc and dependency * update dependency --------- Co-authored-by: Chaoming Wang --- .github/workflows/CI.yml | 62 + README.md | 24 +- brainpy/_src/connect/random_conn.py | 2617 ++++++++------- brainpy/_src/dependency_check.py | 222 +- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/linear.py | 405 +-- brainpy/_src/dnn/tests/test_activation.py | 3 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_linear.py | 441 +-- brainpy/_src/dnn/tests/test_mode.py | 1608 +++++----- brainpy/_src/dnn/tests/test_normalization.py | 5 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- .../_src/dyn/projections/tests/test_STDP.py | 247 +- .../_src/dyn/projections/tests/test_aligns.py | 883 +++--- .../synapses/tests/test_abstract_synapses.py | 256 +- .../tests/test_biological_synapses.py | 211 +- brainpy/_src/math/defaults.py | 18 +- brainpy/_src/math/delayvars.py | 5 +- brainpy/_src/math/environment.py | 23 +- brainpy/_src/math/event/__init__.py | 2 - brainpy/_src/math/event/_csr_matvec.py | 1271 ++------ brainpy/_src/math/event/_info_collection.py | 198 -- .../tests/event_info_VS_jax_operators.py | 275 -- .../_src/math/event/tests/test_event_csrmv.py | 7 + .../math/event/tests/test_event_csrmv_old.py | 324 -- brainpy/_src/math/event/tests/test_info.py | 62 - .../_src/math/event/tests/test_info_gpu.py | 14 - brainpy/_src/math/index_tricks.py | 305 -- brainpy/_src/math/jitconn/__init__.py | 5 +- brainpy/_src/math/jitconn/_event_matvec.py | 2821 ++++++----------- brainpy/_src/math/jitconn/_matvec.py | 2175 ++++--------- .../math/jitconn/tests/test_event_matvec.py | 6 + .../_src/math/jitconn/tests/test_matvec.py | 5 + .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 4 +- .../_src/math/object_transform/controls.py | 136 +- brainpy/_src/math/object_transform/jit.py | 69 +- 
brainpy/_src/math/object_transform/naming.py | 3 +- .../_src/math/object_transform/parallels.py | 460 +++ brainpy/_src/math/object_transform/tools.py | 75 +- .../_src/math/object_transform/variables.py | 45 +- brainpy/_src/math/op_register/__init__.py | 15 +- brainpy/_src/math/op_register/base.py | 30 +- .../op_register/numba_approach/__init__.py | 14 +- .../numba_approach/cpu_translation.py | 298 +- brainpy/_src/math/op_register/numba_based.py | 13 +- .../math/op_register/tests/test_ad_support.py | 7 +- .../op_register/tests/test_numba_based.py | 7 +- .../op_register/tests/test_taichi_based.py | 7 +- .../tests/test_taichi_clean_cache.py | 110 +- brainpy/_src/math/sparse/__init__.py | 5 +- brainpy/_src/math/sparse/_bsr_mm.py | 100 +- brainpy/_src/math/sparse/_csr_mv.py | 896 ++---- brainpy/_src/math/sparse/_utils.py | 3 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 6 +- .../_src/math/sparse/tests/test_csrmv_old.py | 352 -- brainpy/_src/math/tests/test_tifunc.py | 246 +- brainpy/_src/math/tifunc.py | 513 ++- brainpy/_src/tests/test_dyn_runner.py | 267 +- brainpy/_src/tools/functions.py | 192 -- brainpy/_src/tools/progress.py | 519 +++ brainpy/_src/tools/tests/test_functions.py | 24 - brainpy/errors.py | 11 +- brainpy/math/__init__.py | 205 +- brainpy/math/compat_pytorch.py | 2 +- brainpy/math/event.py | 2 - brainpy/math/jitconn.py | 20 +- brainpy/math/oo_transform.py | 4 - brainpy/math/op_register.py | 26 +- brainpy/math/sparse.py | 7 +- brainpy/math/tifunc.py | 51 +- brainpy/tools.py | 5 - docs/advanced_tutorials.rst | 51 +- docs/apis/brainpy.math.oo_transform.rst | 1 - docs/quickstart/installation.rst | 262 +- docs/toolboxes.rst | 38 +- docs/tutorials.rst | 77 +- examples/dynamics_simulation/ei_nets.py | 2 +- examples/dynamics_training/integrator_rnn.py | 4 +- requirements-dev-raw.txt | 12 + requirements-dev.txt | 5 +- requirements.txt | 2 - setup.py | 16 +- 84 files changed, 7956 insertions(+), 11838 deletions(-) delete mode 100644 
brainpy/_src/math/event/_info_collection.py delete mode 100644 brainpy/_src/math/event/tests/event_info_VS_jax_operators.py delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_old.py delete mode 100644 brainpy/_src/math/event/tests/test_info.py delete mode 100644 brainpy/_src/math/event/tests/test_info_gpu.py delete mode 100644 brainpy/_src/math/index_tricks.py create mode 100644 brainpy/_src/math/object_transform/parallels.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_old.py delete mode 100644 brainpy/_src/tools/functions.py create mode 100644 brainpy/_src/tools/progress.py delete mode 100644 brainpy/_src/tools/tests/test_functions.py create mode 100644 requirements-dev-raw.txt diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 84aa028e3..95bd8eafd 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,6 +50,37 @@ jobs: cd brainpy pytest _src/ + test_linux_with_taichi_numba: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ "3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_linux_py37: # runs-on: ubuntu-latest @@ -116,6 +147,37 @@ jobs: cd brainpy pytest _src/ + test_macos_with_taichi_numba: + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_macos_py37: # runs-on: macos-latest # strategy: diff --git a/README.md b/README.md index 6d2ee4bf4..a7fe0b721 100644 --- a/README.md +++ b/README.md @@ -25,29 +25,7 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu ## Installation -BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy: - -```bash -$ pip install brainpy -U -``` - -In addition, many customized operators in BrainPy are implemented in ``brainpylib``. 
-Install the latest version of `brainpylib` by: - -```bash -# CPU installation for Linux, macOS and Windows -$ pip install --upgrade brainpylib -``` - -```bash -# CUDA 12 installation for Linux only -$ pip install --upgrade brainpylib-cu12x -``` - -```bash -# CUDA 11 installation for Linux only -$ pip install --upgrade brainpylib-cu11x -``` +BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 1f5b1db6d..0e4ee769c 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -1,1372 +1,1245 @@ -# -*- coding: utf-8 -*- -from functools import partial -from typing import Optional - -from jax import vmap, jit, numpy as jnp -import numpy as np -from numba import njit - -import brainpy.math as bm -from brainpy.errors import ConnectorError -from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed -from brainpy._src.tools.package import SUPPORT_NUMBA -from .base import * - -__all__ = [ - 'FixedProb', - 'FixedPreNum', - 'FixedPostNum', - 'FixedTotalNum', - 'GaussianProb', - 'ProbDist', - - 'SmallWorld', - 'ScaleFreeBA', - 'ScaleFreeBADual', - 'PowerLaw', -] - - -class FixedProb(TwoEndConnector): - """Connect the post-synaptic neurons with fixed probability. - - Parameters - ---------- - prob: float - The conn probability. - pre_ratio: float - The ratio of pre-synaptic neurons to connect. - include_self : bool - Whether create (i, i) conn? - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - seed : optional, int - Seed the random generator. 
- """ - - def __init__(self, - prob, - pre_ratio=1., - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedProb, self).__init__(**kwargs) - assert 0. <= prob <= 1. - assert 0. <= pre_ratio <= 1. - self.prob = prob - self.pre_ratio = pre_ratio - self.include_self = include_self - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self._jaxrand = bm.random.default_rng(self.seed) - self._nprand = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' - f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' - f'seed={self.seed})') - - def _iii(self): - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - - if self.pre_ratio < 1.: - pre_num_to_select = int(self.pre_num * self.pre_ratio) - pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) - else: - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - - post_num_total = self.post_num - post_num_to_select = int(self.post_num * self.prob) - - if self.allow_multi_conn: - selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self._nprand.randint(0, int(1e8))) - else: - rng = self._nprand - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, 
selected_post_ids, pre_ids = self._iii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - - def build_mat(self): - if self.pre_ratio < 1.: - pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state - else: - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) - mat = bm.asarray(mat) - if not self.include_self: - bm.fill_diagonal(mat, False) - return mat.astype(MAT_DTYPE) - - -class FixedTotalNum(TwoEndConnector): - """Connect the synaptic neurons with fixed total number. - - Parameters - ---------- - num : float,int - The conn total number. - allow_multi_conn : bool, optional - Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. - seed: int, optional - The random number seed. - """ - - def __init__(self, - num, - allow_multi_conn=False, - seed=None, **kwargs): - super().__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' 
- elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) - - def build_coo(self): - mat_element_num = self.pre_num * self.post_num - if self.num > mat_element_num: - raise ConnectorError(f'"num" must be smaller than "all2all num", ' - f'but got {self.num} > {mat_element_num}') - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) - selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) - else: - index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) - selected_pre_ids = index // self.post_num - selected_post_ids = index % self.post_num - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' - - -class FixedNum(TwoEndConnector): - def __init__(self, - num, - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedNum, self).__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.include_self = include_self - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' - - -class FixedPreNum(FixedNum): - """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. 
- - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def build_coo(self): - if isinstance(self.num, int) and self.num > self.pre_num: - raise ConnectorError(f'"num" must be smaller than "pre_num", ' - f'but got {self.num} > {self.pre_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num - pre_num_total = self.pre_num - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(post_num_total): - posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) - return posts - - selected_pre_ids = jnp.asarray(single_conn()) - - post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select - if not self.include_self: - true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) - post_nums -= jnp.sum(true_ids, axis=1) - selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_pre_ids = selected_pre_ids.flatten() - selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) - return selected_pre_ids.astype(get_idx_type()), 
selected_post_ids.astype(get_idx_type()) - - -class FixedPostNum(FixedNum): - """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def _ii(self): - if isinstance(self.num, int) and self.num > self.post_num: - raise ConnectorError(f'"num" must be smaller than "post_num", ' - f'but got {self.num} > {self.post_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._ii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, 
post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - -@jit -@partial(vmap, in_axes=(0, None, None)) -def gaussian_prob_dist_cal1(i_value, post_values, sigma): - dists = jnp.abs(i_value - post_values) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - -@jit -@partial(vmap, in_axes=(0, None, None, None)) -def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): - dists = jnp.abs(i_value - post_values) - dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - - -class GaussianProb(OneEndConnector): - r"""Builds a Gaussian connectivity pattern within a population of neurons, - where the connection probability decay according to the gaussian function. - - Specifically, for any pair of neurons :math:`(i, j)`, - - .. math:: - - p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) - - where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. 
- - Parameters - ---------- - sigma : float - Width of the Gaussian function. - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. For example, ``values=(0, np.pi)``, - neurons at each dimension will encode a continuous value space ``[0, np.pi]``. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. - - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - normalize : bool - Whether normalize the connection probability . - include_self : bool - Whether create the connection at the same position. - seed : int - The random seed. 
- """ - - def __init__( - self, - sigma: float, - encoding_values: Optional[np.ndarray] = None, - normalize: bool = True, - include_self: bool = True, - periodic_boundary: bool = False, - seed: int = None, - **kwargs - ): - super(GaussianProb, self).__init__(**kwargs) - self.sigma = sigma - self.encoding_values = encoding_values - self.normalize = normalize - self.include_self = include_self - self.periodic_boundary = periodic_boundary - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(sigma={self.sigma}, ' - f'normalize={self.normalize}, ' - f'periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - self.rng = np.random.RandomState(self.seed) - # value range to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in self.pre_size]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ConnectorError(f'encoding_values has a length of 0.') - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in self.pre_size]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if len(self.encoding_values) != len(self.pre_size): - raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. 
' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] - # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - - # probability of connections - if isOptimized: - i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) - for i in range(self.pre_num): - list_index = i - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - i_value_list[list_index] = i_value - - if self.periodic_boundary: - prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) - else: - prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) - else: - prob_mat = [] - for i in range(self.pre_num): - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # 
distances - dists = np.abs(i_value - post_values) - if self.periodic_boundary: - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) - prob_mat.append(exp_dists) - prob_mat = np.stack(prob_mat) - - if self.normalize: - prob_mat /= prob_mat.max() - - # connectivity - conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) - if not self.include_self: - np.fill_diagonal(conn_mat, False) - return conn_mat - - -class SmallWorld(TwoEndConnector): - """Build a Watts–Strogatz small-world graph. - - Parameters - ---------- - num_neighbor : int - Each node is joined with its `k` nearest neighbors in a ring - topology. - prob : float - The probability of rewiring each edge - directed : bool - Whether the graph is a directed graph. - include_self : bool - Whether include the node self. - - Notes - ----- - First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is - joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors - if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as - follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with - :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new - edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. - - References - ---------- - .. [1] Duncan J. Watts and Steven H. Strogatz, - Collective dynamics of small-world networks, - Nature, 393, pp. 440--442, 1998. 
- """ - - def __init__( - self, - num_neighbor, - prob, - directed=False, - include_self=False, - seed=None, - **kwargs - ): - super(SmallWorld, self).__init__(**kwargs) - self.prob = prob - self.directed = directed - self.num_neighbor = num_neighbor - self.include_self = include_self - - self.seed = format_seed(seed) - self.rng = np.random.RandomState(seed=self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _smallworld_rewire(i, all_j): - if rng.random(1) < prob: - non_connected = np.where(np.logical_not(all_j))[0] - if len(non_connected) <= 1: - return -1 - # Enforce no self-loops or multiple edges - w = rng.choice(non_connected) - while (not include_self) and w == i: - # non_connected.remove(w) - w = rng.choice(non_connected) - return w - else: - return -1 - - self._connect = numba_jit(_smallworld_rewire) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, ' - f'directed={self.directed}, ' - f'num_neighbor={self.num_neighbor}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_conn(self): - assert self.pre_size == self.post_size - - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) - - if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): - num_node = self.pre_num - - if self.num_neighbor > num_node: - raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") - # If k == n, the graph is complete not Watts-Strogatz - if self.num_neighbor == num_node: - conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) - else: - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 - # connect each node to k/2 neighbors - for j in range(1, self.num_neighbor // 2 + 1): - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - conn[nodes, targets] = True - conn[targets, nodes] = True - - 
# rewire edges from each node - # loop over all nodes in order (label) and neighbors in order (distance) - # no self loops or multiple edges allowed - for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - if self.directed: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(prob=self.prob, i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[u, w] = True - w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) - if w != -1: - conn[v, u] = False - conn[w, u] = True - else: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[v, u] = False - conn[u, w] = True - conn[w, u] = True - # conn = np.asarray(conn, dtype=MAT_DTYPE) - else: - raise ConnectorError('Currently only support 1D ring connection.') - - return 'mat', conn - - -# def _random_subset(seq, m, rng): -# """Return m unique elements from seq. -# -# This differs from random.sample which can return repeated -# elements if seq holds repeated elements. -# -# Note: rng is a random.Random or numpy.random.RandomState instance. -# """ -# targets = set() -# while len(targets) < m: -# x = rng.choice(seq) -# targets.add(x) -# return targets - - -class ScaleFreeBA(TwoEndConnector): - """Build a random graph according to the Barabási–Albert preferential - attachment model. - - A graph of :math:`num\_node` nodes is grown by attaching new nodes each with - :math:`m` edges that are preferentially attached to existing nodes - with high degree. - - Parameters - ---------- - m : int - Number of edges to attach from a new node to existing nodes - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m` does not satisfy ``1 <= m < n``. - - References - ---------- - .. [1] A. L. Barabási and R. 
Albert "Emergence of scaling in - random networks", Science 286, pp 509-512, 1999. - """ - - def __init__(self, m, directed=False, seed=None, **kwargs): - super(ScaleFreeBA, self).__init__(**kwargs) - self.m = m - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, ' - f'directed={self.directed}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m < 1 or self.m >= num_node: - raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " - f"m < n, while m = {self.m} and n = {num_node}") - - # Add m initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - # Target nodes for new edges - targets = list(range(self.m)) - # List of existing nodes, with nodes repeated once for each adjacent edge - - if not isOptimized: - repeated_nodes = [] - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. 
- repeated_nodes.extend([source] * self.m) - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), self.m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) - size_repeated_nodes = 0 - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets - size_repeated_nodes += self.m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source - size_repeated_nodes += self.m - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) - source += 1 - - return conn - - -class ScaleFreeBADual(TwoEndConnector): - r"""Build a random graph according to the dual Barabási–Albert preferential - attachment model. - - A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ - edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that - are preferentially attached to existing nodes with high degree. 
- - Parameters - ---------- - m1 : int - Number of edges to attach from a new node to existing nodes with probability :math:`p` - m2 : int - Number of edges to attach from a new node to existing nodes with probability :math:`1-p` - p : float - The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. - - References - ---------- - .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. - """ - - def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): - super(ScaleFreeBADual, self).__init__(**kwargs) - self.m1 = m1 - self.m2 = m2 - self.p = p - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' - f'p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m1 < 1 or self.m1 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " - f"while m1 = {self.m1} and num_node = {num_node}.") - if self.m2 < 1 or self.m2 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " - f"while m2 = {self.m2} and num_node = {num_node}.") - if self.p < 0 or self.p > 1: - raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = 
{self.p}") - - # Add max(m1,m2) initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - - if not isOptimized: - # List of existing nodes, with nodes repeated once for each adjacent edge - repeated_nodes = [] - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * m) - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) - size_repeated_nodes = 0 - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. 
- repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets - size_repeated_nodes += m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source - size_repeated_nodes += m - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) - source += 1 - - return conn - - -class PowerLaw(TwoEndConnector): - """Holme and Kim algorithm for growing graphs with powerlaw - degree distribution and approximate average clustering. - - Parameters - ---------- - m : int - the number of random edges to add for each new node - p : float, - Probability of adding a triangle after adding a random edge - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Notes - ----- - The average clustering has a hard time getting above a certain - cutoff that depends on :math:`m`. This cutoff is often quite low. The - transitivity (fraction of triangles to possible triangles) seems to - decrease with network size. - - It is essentially the Barabási–Albert (BA) growth model with an - extra step that each random edge is followed by a chance of - making an edge to one of its neighbors too (and thus a triangle). - - This algorithm improves on BA in the sense that it enables a - higher average clustering to be attained if desired. - - It seems possible to have a disconnected graph with this algorithm - since the initial :math:`m` nodes may not be all linked to a new node - on the first iteration like the BA model. - - Raises - ------ - ConnectorError - If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not - satisfy :math:`0 <= p <= 1`. - - References - ---------- - .. [1] P. Holme and B. J. 
Kim, - "Growing scale-free networks with tunable clustering", - Phys. Rev. E, 65, 026107, 2002. - """ - - def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): - super(PowerLaw, self).__init__(**kwargs) - self.m = m - self.p = p - if self.p > 1 or self.p < 0: - raise ConnectorError(f"p must be in [0,1], while p={self.p}") - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - num_node = self.pre_num - if self.m < 1 or num_node < self.m: - raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) - size = np.prod(pre_size) - - for i in range(size): - pre_pos = np.asarray([p[i] for p in pre_ids]) - pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) - connected_pres.extend(pres) - connected_posts.extend(posts) - return np.asarray(connected_pres), np.asarray(connected_posts) +# -*- coding: utf-8 -*- + +from functools import partial +from typing import Optional + +from jax import vmap, jit, numpy as jnp +import numpy as np + +import brainpy.math as bm +from brainpy.errors import ConnectorError +from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed +from brainpy._src.tools.package import SUPPORT_NUMBA +from .base import * + + +__all__ = [ + 'FixedProb', + 'FixedPreNum', + 'FixedPostNum', + 'FixedTotalNum', + 'GaussianProb', + 'ProbDist', + + 'SmallWorld', + 'ScaleFreeBA', + 'ScaleFreeBADual', + 'PowerLaw', +] + + +class 
FixedProb(TwoEndConnector): + """Connect the post-synaptic neurons with fixed probability. + + Parameters + ---------- + prob: float + The conn probability. + pre_ratio: float + The ratio of pre-synaptic neurons to connect. + include_self : bool + Whether create (i, i) conn? + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + seed : optional, int + Seed the random generator. + """ + + def __init__(self, + prob, + pre_ratio=1., + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedProb, self).__init__(**kwargs) + assert 0. <= prob <= 1. + assert 0. <= pre_ratio <= 1. + self.prob = prob + self.pre_ratio = pre_ratio + self.include_self = include_self + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self._jaxrand = bm.random.default_rng(self.seed) + self._nprand = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' + f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' + f'seed={self.seed})') + + def _iii(self): + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). 
' + f'But `include_self` is set to True.') + + if self.pre_ratio < 1.: + pre_num_to_select = int(self.pre_num * self.pre_ratio) + pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) + else: + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + + post_num_total = self.post_num + post_num_to_select = int(self.post_num * self.prob) + + if self.allow_multi_conn: + selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self._nprand.randint(0, int(1e8))) + else: + rng = self._nprand + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + 
selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + def build_mat(self): + if self.pre_ratio < 1.: + pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + else: + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) + mat = bm.asarray(mat) + if not self.include_self: + bm.fill_diagonal(mat, False) + return mat.astype(MAT_DTYPE) + + +class FixedTotalNum(TwoEndConnector): + """Connect the synaptic neurons with fixed total number. + + Parameters + ---------- + num : float,int + The conn total number. + allow_multi_conn : bool, optional + Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. + seed: int, optional + The random number seed. + """ + + def __init__(self, + num, + allow_multi_conn=False, + seed=None, **kwargs): + super().__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' 
+ else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) + + def build_coo(self): + mat_element_num = self.pre_num * self.post_num + if self.num > mat_element_num: + raise ConnectorError(f'"num" must be smaller than "all2all num", ' + f'but got {self.num} > {mat_element_num}') + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) + selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + else: + index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) + selected_pre_ids = index // self.post_num + selected_post_ids = index % self.post_num + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' + + +class FixedNum(TwoEndConnector): + def __init__(self, + num, + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedNum, self).__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' + else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.include_self = include_self + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' + + +class FixedPreNum(FixedNum): + """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). 
+ include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def build_coo(self): + if isinstance(self.num, int) and self.num > self.pre_num: + raise ConnectorError(f'"num" must be smaller than "pre_num", ' + f'but got {self.num} > {self.pre_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num + pre_num_total = self.pre_num + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(post_num_total): + posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) + return posts + + selected_pre_ids = jnp.asarray(single_conn()) + + post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select + if not self.include_self: + true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) + post_nums -= jnp.sum(true_ids, axis=1) + selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_pre_ids = selected_pre_ids.flatten() + selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + +class FixedPostNum(FixedNum): + """Connect the fixed number of post-synaptic neurons for each 
pre-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def _ii(self): + if isinstance(self.num, int) and self.num > self.post_num: + raise ConnectorError(f'"num" must be smaller than "post_num", ' + f'but got {self.num} > {self.post_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._ii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] 
+ selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + +@jit +@partial(vmap, in_axes=(0, None, None)) +def gaussian_prob_dist_cal1(i_value, post_values, sigma): + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +@jit +@partial(vmap, in_axes=(0, None, None, None)) +def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +class GaussianProb(OneEndConnector): + r"""Builds a Gaussian connectivity pattern within a population of neurons, + where the connection probability decay according to the gaussian function. + + Specifically, for any pair of neurons :math:`(i, j)`, + + .. math:: + + p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) + + where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. + + Parameters + ---------- + sigma : float + Width of the Gaussian function. + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. 
+ + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, ``values=(0, np.pi)``, + neurons at each dimension will encode a continuous value space ``[0, np.pi]``. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. + + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + normalize : bool + Whether normalize the connection probability . + include_self : bool + Whether create the connection at the same position. + seed : int + The random seed. + """ + + def __init__( + self, + sigma: float, + encoding_values: Optional[np.ndarray] = None, + normalize: bool = True, + include_self: bool = True, + periodic_boundary: bool = False, + seed: int = None, + **kwargs + ): + super(GaussianProb, self).__init__(**kwargs) + self.sigma = sigma + self.encoding_values = encoding_values + self.normalize = normalize + self.include_self = include_self + self.periodic_boundary = periodic_boundary + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(sigma={self.sigma}, ' + f'normalize={self.normalize}, ' + f'periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + self.rng = np.random.RandomState(self.seed) + # value range to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in self.pre_size]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ConnectorError(f'encoding_values has a length of 0.') + elif 
isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in self.pre_size]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(self.pre_size): + raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. ' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] + # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + + # probability of connections + if isOptimized: + i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) + for i in range(self.pre_num): + list_index = i + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + i_value_list[list_index] = i_value + + if self.periodic_boundary: + prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + 
prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) + else: + prob_mat = [] + for i in range(self.pre_num): + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # distances + dists = np.abs(i_value - post_values) + if self.periodic_boundary: + dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) + exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) + prob_mat.append(exp_dists) + prob_mat = np.stack(prob_mat) + + if self.normalize: + prob_mat /= prob_mat.max() + + # connectivity + conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) + if not self.include_self: + np.fill_diagonal(conn_mat, False) + return conn_mat + + +class SmallWorld(TwoEndConnector): + """Build a Watts–Strogatz small-world graph. + + Parameters + ---------- + num_neighbor : int + Each node is joined with its `k` nearest neighbors in a ring + topology. + prob : float + The probability of rewiring each edge + directed : bool + Whether the graph is a directed graph. + include_self : bool + Whether include the node self. + + Notes + ----- + First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is + joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors + if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as + follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with + :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new + edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. + + References + ---------- + .. [1] Duncan J. Watts and Steven H. 
Strogatz, + Collective dynamics of small-world networks, + Nature, 393, pp. 440--442, 1998. + """ + + def __init__( + self, + num_neighbor, + prob, + directed=False, + include_self=False, + seed=None, + **kwargs + ): + super(SmallWorld, self).__init__(**kwargs) + self.prob = prob + self.directed = directed + self.num_neighbor = num_neighbor + self.include_self = include_self + + self.seed = format_seed(seed) + self.rng = np.random.RandomState(seed=self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _smallworld_rewire(i, all_j): + if rng.random(1) < prob: + non_connected = np.where(np.logical_not(all_j))[0] + if len(non_connected) <= 1: + return -1 + # Enforce no self-loops or multiple edges + w = rng.choice(non_connected) + while (not include_self) and w == i: + # non_connected.remove(w) + w = rng.choice(non_connected) + return w + else: + return -1 + + self._connect = numba_jit(_smallworld_rewire) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, ' + f'directed={self.directed}, ' + f'num_neighbor={self.num_neighbor}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_conn(self): + assert self.pre_size == self.post_size + + # seed + self.seed = self.rng.randint(1, int(1e7)) + numba_seed(self.seed) + + if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): + num_node = self.pre_num + + if self.num_neighbor > num_node: + raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") + # If k == n, the graph is complete not Watts-Strogatz + if self.num_neighbor == num_node: + conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) + else: + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 + # connect each node to k/2 neighbors + for j in range(1, self.num_neighbor // 2 + 1): + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j 
nodes are now last in list + conn[nodes, targets] = True + conn[targets, nodes] = True + + # rewire edges from each node + # loop over all nodes in order (label) and neighbors in order (distance) + # no self loops or multiple edges allowed + for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + if self.directed: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(prob=self.prob, i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[u, w] = True + w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) + if w != -1: + conn[v, u] = False + conn[w, u] = True + else: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[v, u] = False + conn[u, w] = True + conn[w, u] = True + # conn = np.asarray(conn, dtype=MAT_DTYPE) + else: + raise ConnectorError('Currently only support 1D ring connection.') + + return 'mat', conn + + +# def _random_subset(seq, m, rng): +# """Return m unique elements from seq. +# +# This differs from random.sample which can return repeated +# elements if seq holds repeated elements. +# +# Note: rng is a random.Random or numpy.random.RandomState instance. +# """ +# targets = set() +# while len(targets) < m: +# x = rng.choice(seq) +# targets.add(x) +# return targets + + +class ScaleFreeBA(TwoEndConnector): + """Build a random graph according to the Barabási–Albert preferential + attachment model. + + A graph of :math:`num\_node` nodes is grown by attaching new nodes each with + :math:`m` edges that are preferentially attached to existing nodes + with high degree. + + Parameters + ---------- + m : int + Number of edges to attach from a new node to existing nodes + seed : integer, random_state, or None (default) + Indicator of random number generation state. 
+ + Raises + ------ + ConnectorError + If `m` does not satisfy ``1 <= m < n``. + + References + ---------- + .. [1] A. L. Barabási and R. Albert "Emergence of scaling in + random networks", Science 286, pp 509-512, 1999. + """ + + def __init__(self, m, directed=False, seed=None, **kwargs): + super(ScaleFreeBA, self).__init__(**kwargs) + self.m = m + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, ' + f'directed={self.directed}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m < 1 or self.m >= num_node: + raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " + f"m < n, while m = {self.m} and n = {num_node}") + + # Add m initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + # Target nodes for new edges + targets = list(range(self.m)) + # List of existing nodes, with nodes repeated once for each adjacent edge + + if not isOptimized: + repeated_nodes = [] + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. 
+ repeated_nodes.extend([source] * self.m) + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), self.m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) + size_repeated_nodes = 0 + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets + size_repeated_nodes += self.m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source + size_repeated_nodes += self.m + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) + source += 1 + + return conn + + +class ScaleFreeBADual(TwoEndConnector): + r"""Build a random graph according to the dual Barabási–Albert preferential + attachment model. + + A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ + edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that + are preferentially attached to existing nodes with high degree. 
+ + Parameters + ---------- + m1 : int + Number of edges to attach from a new node to existing nodes with probability :math:`p` + m2 : int + Number of edges to attach from a new node to existing nodes with probability :math:`1-p` + p : float + The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises + ------ + ConnectorError + If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. + + References + ---------- + .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. + """ + + def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): + super(ScaleFreeBADual, self).__init__(**kwargs) + self.m1 = m1 + self.m2 = m2 + self.p = p + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' + f'p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m1 < 1 or self.m1 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " + f"while m1 = {self.m1} and num_node = {num_node}.") + if self.m2 < 1 or self.m2 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " + f"while m2 = {self.m2} and num_node = {num_node}.") + if self.p < 0 or self.p > 1: + raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = 
{self.p}") + + # Add max(m1,m2) initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + + if not isOptimized: + # List of existing nodes, with nodes repeated once for each adjacent edge + repeated_nodes = [] + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * m) + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) + size_repeated_nodes = 0 + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. 
+ repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets + size_repeated_nodes += m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source + size_repeated_nodes += m + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) + source += 1 + + return conn + + +class PowerLaw(TwoEndConnector): + """Holme and Kim algorithm for growing graphs with powerlaw + degree distribution and approximate average clustering. + + Parameters + ---------- + m : int + the number of random edges to add for each new node + p : float, + Probability of adding a triangle after adding a random edge + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Notes + ----- + The average clustering has a hard time getting above a certain + cutoff that depends on :math:`m`. This cutoff is often quite low. The + transitivity (fraction of triangles to possible triangles) seems to + decrease with network size. + + It is essentially the Barabási–Albert (BA) growth model with an + extra step that each random edge is followed by a chance of + making an edge to one of its neighbors too (and thus a triangle). + + This algorithm improves on BA in the sense that it enables a + higher average clustering to be attained if desired. + + It seems possible to have a disconnected graph with this algorithm + since the initial :math:`m` nodes may not be all linked to a new node + on the first iteration like the BA model. + + Raises + ------ + ConnectorError + If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not + satisfy :math:`0 <= p <= 1`. + + References + ---------- + .. [1] P. Holme and B. J. 
Kim, + "Growing scale-free networks with tunable clustering", + Phys. Rev. E, 65, 026107, 2002. + """ + + def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): + super(PowerLaw, self).__init__(**kwargs) + self.m = m + self.p = p + if self.p > 1 or self.p < 0: + raise ConnectorError(f"p must be in [0,1], while p={self.p}") + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + num_node = self.pre_num + if self.m < 1 or num_node < self.m: + raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) + size = np.prod(pre_size) + + for i in range(size): + pre_pos = np.asarray([p[i] for p in pre_ids]) + pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) + connected_pres.extend(pres) + connected_posts.extend(posts) + return np.asarray(connected_pres), np.asarray(connected_posts) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 3bba20a79..b8bd6e99a 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,87 +1,135 @@ -import os -import sys -from jax.lib import xla_client - -__all__ = [ - 'import_taichi', - 'import_brainpylib_cpu_ops', - 'import_brainpylib_gpu_ops', -] - -_minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 0) - -taichi = None -brainpylib_cpu_ops = None -brainpylib_gpu_ops = None - -taichi_install_info = (f'We need 
taichi=={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' - '> pip install taichi==1.7.0') -os.environ["TI_LOG_LEVEL"] = "error" - - -def import_taichi(): - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - raise ModuleNotFoundError(taichi_install_info) - finally: - sys.stdout = old_stdout - - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi - - -def is_brainpylib_gpu_installed(): - return False if brainpylib_gpu_ops is None else True - - -def import_brainpylib_cpu_ops(): - global brainpylib_cpu_ops - if brainpylib_cpu_ops is None: - try: - from brainpylib import cpu_ops as brainpylib_cpu_ops - - for _name, _value in brainpylib_cpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install brainpylib. 
\n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_cpu_ops - - -def import_brainpylib_gpu_ops(): - global brainpylib_gpu_ops - if brainpylib_gpu_ops is None: - try: - from brainpylib import gpu_ops as brainpylib_gpu_ops - - for _name, _value in brainpylib_gpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install GPU version of brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_gpu_ops +import os +import sys + +from jax.lib import xla_client + +__all__ = [ + 'import_taichi', + 'raise_taichi_not_found', + 'import_numba', + 'raise_numba_not_found', + 'import_brainpylib_cpu_ops', + 'import_brainpylib_gpu_ops', +] + +_minimal_brainpylib_version = '0.2.6' +_minimal_taichi_version = (1, 7, 0) + +numba = None +taichi = None +brainpylib_cpu_ops = None +brainpylib_gpu_ops = None + +taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' + f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' + '> pip install taichi==1.7.0') +numba_install_info = ('We need numba. Please install numba by pip . \n' + '> pip install numba') +os.environ["TI_LOG_LEVEL"] = "error" + + +def import_taichi(error_if_not_found=True): + """Internal API to import taichi. + + If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. 
+ """ + global taichi + if taichi is None: + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except ModuleNotFoundError: + if error_if_not_found: + raise raise_taichi_not_found() + finally: + sys.stdout = old_stdout + + if taichi is None: + return None + if taichi.__version__ != _minimal_taichi_version: + raise RuntimeError(taichi_install_info) + return taichi + + +def raise_taichi_not_found(*args, **kwargs): + raise ModuleNotFoundError(taichi_install_info) + + +def import_numba(error_if_not_found=True): + """ + Internal API to import numba. + + If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global numba + if numba is None: + try: + import numba as numba + except ModuleNotFoundError: + if error_if_not_found: + raise_numba_not_found() + else: + return None + return numba + + +def raise_numba_not_found(): + raise ModuleNotFoundError(numba_install_info) + + +def is_brainpylib_gpu_installed(): + return False if brainpylib_gpu_ops is None else True + + +def import_brainpylib_cpu_ops(): + """ + Internal API to import brainpylib cpu_ops. + """ + global brainpylib_cpu_ops + if brainpylib_cpu_ops is None: + try: + from brainpylib import cpu_ops as brainpylib_cpu_ops + + for _name, _value in brainpylib_cpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="cpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install brainpylib. 
\n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_cpu_ops + + +def import_brainpylib_gpu_ops(): + """ + Internal API to import brainpylib gpu_ops. + """ + global brainpylib_gpu_ops + if brainpylib_gpu_ops is None: + try: + from brainpylib import gpu_ops as brainpylib_gpu_ops + + for _name, _value in brainpylib_gpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="gpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install GPU version of brainpylib. \n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_gpu_ops diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index e4b6e25d2..deead1f3b 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = bm.unsqueeze(x, 0) + x = x.unsqueeze(0) w = self.w.value if self.mask is not None: try: @@ -190,9 +190,6 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. - The input should a 2d array with the shape of ``[H, C]``, or - a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - Parameters ---------- in_channels: int @@ -285,9 +282,6 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, C]``, or - a 4d array with the shape of ``[B, H, W, C]``. - Parameters ---------- in_channels: int @@ -381,9 +375,6 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. 
- The input should a 3d array with the shape of ``[H, W, D, C]``, or - a 4d array with the shape of ``[B, H, W, D, C]``. - Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 539214d3b..c524fb0bf 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -6,22 +6,21 @@ import jax import jax.numpy as jnp -import numba import numpy as np from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share +from brainpy._src.dependency_check import import_taichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_taichi from brainpy.check import is_initializer from brainpy.connect import csr2csc -from brainpy.errors import MathError +from brainpy.errors import MathError, PackageMissingError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'Dense', 'Linear', @@ -239,140 +238,106 @@ def update(self, x): return x -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - -@ti.kernel -def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 +if ti 
is not None: + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_post( + old_w: ti.types.ndarray(ndim=2), + post_spike: ti.types.ndarray(ndim=1), + pre_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if post_spike[j]: + new_value = out_w[i, j] + pre_trace[i] if new_value < w_min0: out_w[i, j] = w_min0 elif new_value > w_max0: out_w[i, j] = w_max0 else: - out_w[i, j] = new_value - - -@ti.kernel -def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_pre( + old_w: ti.types.ndarray(ndim=2), + pre_spike: ti.types.ndarray(ndim=1), + post_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: 
ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if pre_spike[i]: + new_value = out_w[i, j] + post_trace[j] if new_value < w_min0: out_w[i, j] = w_min0 elif new_value > w_max0: out_w[i, j] = w_max0 else: out_w[i, j] = new_value - + else: + out_w[i, j] = old_w[i, j] -dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, - gpu_kernel=_gpu_dense_on_pre) + + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) + +else: + dense_on_pre_prim = None + dense_on_post_prim = None def dense_on_pre(weight, spike, trace, w_min, w_max): + if dense_on_pre_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - -@ti.kernel -def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = 
new_value - -@ti.kernel -def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - -dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, - gpu_kernel=_gpu_dense_on_post) - - def dense_on_post(weight, spike, trace, w_min, w_max): + if dense_on_post_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_post_prim(weight, spike, trace, w_min, w_max, @@ -630,7 +595,7 @@ def stdp_update( raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike + if on_pre is not None: # update on presynaptic spike spike = on_pre['spike'] trace = on_pre['trace'] self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) @@ -682,8 +647,7 @@ def __init__( def update(self, x): if x.ndim == 1: return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: 
shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -694,8 +658,8 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. @@ -746,99 +710,170 @@ def _batch_csrmv(self, x): shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): -# out_w[:] = w -# w_min = w_min[()] -# w_max = w_max[()] -# for i in numba.prange(spike.shape[0]): # pre id -# if spike[i]: -# for k in range(indptr[i], indptr[i + 1]): # synapse id -# j = indices[k] # post id -# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) -# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - - -@ti.kernel -def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) -@ti.kernel -def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - 
trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) - - -csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, - gpu_kernel=_gpu_csr_on_pre_update) + +if ti is not None: + @ti.kernel + def _csr_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) + spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre = spike.shape[0] + for i_pre in range(num_pre): + if spike[i_pre]: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) + else: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = old_w[i_syn] + + + csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + + + @ti.kernel + def _coo_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = 
w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if pre_spike[pre_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + + + @ti.kernel + def _coo_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if post_spike[post_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + # out_w[:] = w + # w_min = w_min[()] + # w_max = w_max[()] + # for i in numba.prange(spike.shape[0]): # post id + # if spike[i]: + # for k in range(indptr[i], indptr[i + 1]): + # j = post_ids[k] # pre id + # l = w_ids[k] # syn id + # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + + @ti.kernel + def _csc_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape 
of (num_post + 1) + w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_post = post_spike.shape[0] + for i_post in range(num_post): + if post_spike[i_post]: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) + else: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = old_w[i_syn] + + + csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) + + +else: + csr_on_pre_update_prim = None + coo_on_pre_update_prim = None + csc_on_post_update_prim = None def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if csr_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) +def coo_on_pre_update(w, pre_ids, post_ids, spike, 
trace, w_min=None, w_max=None): + if coo_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') -csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) - - -def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] +def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): + if csc_on_post_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + class CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. 
diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 17054667d..ba2a49efd 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,6 +1,5 @@ -from absl.testing import absltest from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f523622..3c9fdfa87 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp +from unittest import TestCase from absl.testing import absltest +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import parameterized - import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): + bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -22,7 +24,6 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -30,6 +31,7 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): + bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -37,7 +39,6 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -53,7 +54,6 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # 
plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,7 +67,6 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 9ad15938d..269fec441 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- +from unittest import TestCase + +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized - import brainpy as bp -import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index df5293ab9..422f161f1 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,218 +1,223 @@ -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class TestLinear(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bm.random.seed() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - num_out=[20, 10, 5] - ) - def test_Dense1(self, size, num_out): - bm.random.seed() - f = bp.dnn.Linear(10, num_out) - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size[:-1] + (num_out,)) - bm.clear_buffer_memory() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - ) - def test_Identity(self, size): - bm.random.seed() - f = bp.dnn.Identity() - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size) - bm.clear_buffer_memory() - - def test_AllToAll1(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, 
include_self=True) - x = bm.random.random((8, 10)) - y = f(x) - expected = bm.sum(x, axis=1, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((10,)) - y = f(x) - expected = bm.sum(x, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - def test_OneToOne(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((8, 10)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((10,)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - # bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_MaskedLinear(self, conn): - bm.random.seed() - bm.random.DEFAULT.seed(123) - f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def 
test_EventCSRLinear(self,conn): - bm.random.seed() - f=bp.layers.EventCSRLinear(conn,weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - 
self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - -if __name__ == '__main__': - absltest.main() +import pytest +from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestLinear(parameterized.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bm.random.seed() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + num_out=[20, 10, 5] + ) + def test_Dense1(self, size, num_out): + bm.random.seed() + f = bp.dnn.Linear(10, num_out) + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size[:-1] + (num_out,)) + bm.clear_buffer_memory() + + 
@parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + ) + def test_Identity(self, size): + bm.random.seed() + f = bp.dnn.Identity() + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size) + bm.clear_buffer_memory() + + def test_AllToAll1(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((8, 10)) + y = f(x) + expected = bm.sum(x, axis=1, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((10,)) + y = f(x) + expected = bm.sum(x, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + def test_OneToOne(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((8, 10)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((10,)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + # bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_MaskedLinear(self, conn): + bm.random.seed() + bm.random.DEFAULT.seed(123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear(conn, 
weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear(self, conn): + bm.random.seed() + f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + 
weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 3cf923d7b..f0c67da12 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,801 +1,807 @@ -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class 
Test_Conv(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv2_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv3_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - 
bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose2d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose3d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - -class TestPool(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - 
bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MinPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AvgPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for 
i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.MaxPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool1d(self, mode): - bm.random.seed() - input 
= bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - 
@parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - -class Test_Dropout(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Dropout(self, mode): - bp.share.save(fit=False) - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.Dropout(prob=0.2, - mode=mode) - output = layer(input) - - -class Test_function(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Flatten(self, mode): - bm.random.seed() - layer = bp.dnn.Flatten(mode=mode) - input = bm.random.randn(10, 5, 5, 5, 4) - output = layer(input) - - -class Test_linear(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_linear(self, mode): - bm.random.seed() - input = bm.random.randn(10, 9, 8, 7) - layer = bp.dnn.Linear(num_in=7, - num_out=6, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AllToAll(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.AllToAll(num_pre=10, - num_post=20, - weight=0.1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - 
@parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_OneToOne(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.OneToOne(num=10, - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaskedLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_CSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventCSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPUniformLinear(self, 
mode): - bm.random.seed() - layer = bp.dnn.JitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - -class Test_Normalization(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - 
bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm1d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm1d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm2d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm2d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm3d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm3d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 7, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_LayerNorm(self, mode): - bm.random.seed() - layer = bp.dnn.LayerNorm(normalized_shape=3, - mode=mode, - elementwise_affine=False - ) - input = bm.random.randn(10, 5, 3) - outout = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_GroupNorm(self, mode): - bm.random.seed() - layer = bp.dnn.GroupNorm(num_groups=2, - num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_InstanceNorm(self, mode): - bm.random.seed() - layer = bp.dnn.InstanceNorm(num_channels=6, - 
affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - -if __name__ == '__main__': - absltest.main() +import pytest +from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Conv(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv2_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.Conv3d(in_channels=3, + 
out_channels=4, + kernel_size=(5, 5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv3_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose2d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode + ) + 
output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose3d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + +class TestPool(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MinPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AvgPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + 
for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.MaxPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool3d(self, mode): + bm.random.seed() + input = 
bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + 
@parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + +class Test_Dropout(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Dropout(self, mode): + bp.share.save(fit=False) + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.Dropout(prob=0.2, + mode=mode) + output = layer(input) + + +class Test_function(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Flatten(self, mode): + bm.random.seed() + layer = bp.dnn.Flatten(mode=mode) + input = bm.random.randn(10, 5, 5, 5, 4) + output = layer(input) + + +class Test_linear(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_linear(self, mode): + bm.random.seed() + input = bm.random.randn(10, 9, 8, 7) + layer = bp.dnn.Linear(num_in=7, + num_out=6, + mode=mode) + output = 
layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AllToAll(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.AllToAll(num_pre=10, + num_post=20, + weight=0.1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_OneToOne(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.OneToOne(num=10, + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaskedLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_CSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventCSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + 
bm.NonBatchingMode()] + ) + def test_JitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + 
bm.NonBatchingMode()] + ) + def test_EventJitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + +class Test_Normalization(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm1d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm1d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm2d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm2d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm3d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm3d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 7, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_LayerNorm(self, mode): + bm.random.seed() + layer = bp.dnn.LayerNorm(normalized_shape=3, + mode=mode, + elementwise_affine=False + ) + input = bm.random.randn(10, 5, 3) + outout = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_GroupNorm(self, 
mode): + bm.random.seed() + layer = bp.dnn.GroupNorm(num_groups=2, + num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_InstanceNorm(self, mode): + bm.random.seed() + layer = bp.dnn.InstanceNorm(num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index de2c9765b..fdc5b34e3 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,8 +1,7 @@ -from absl.testing import absltest +import brainpy.math as bm from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp -import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 5748edd8b..34f8f5cd5 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest from absl.testing import parameterized +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index b8884f327..18d9d9dc9 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,120 +1,127 @@ -# -*- coding: utf-8 -*- - - -import numpy as np -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class 
Test_STDP(parameterized.TestCase): - - @parameterized.product( - comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], - delay=[None, 0., 2.], - syn_model=['exp', 'dual_exp', 'ampa'], - out_model=['cuba', 'coba', 'mg'] - ) - def test_STDP(self, comm_method, delay, syn_model, out_model): - bm.random.seed() - - class STDPNet(bp.DynamicalSystem): - def __init__(self, num_pre, num_post): - super().__init__() - self.pre = bp.dyn.LifRef(num_pre) - self.post = bp.dyn.LifRef(num_post) - - if comm_method == 'all2all': - comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'csr': - if syn_model == 'exp': - comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - else: - comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'masked_linear': - comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'dense': - comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'one2one': - comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) - else: - raise ValueError - - if syn_model == 'exp': - syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) - elif syn_model == 'dual_exp': - syn = bp.dyn.DualExpon.desc(self.post.varshape) - elif syn_model == 'dual_exp_v2': - syn = bp.dyn.DualExponV2.desc(self.post.varshape) - elif syn_model == 'ampa': - syn = bp.dyn.AMPA.desc(self.post.varshape) - else: - raise ValueError - - if out_model == 'cuba': - out = bp.dyn.CUBA.desc() - elif out_model == 'coba': - out = bp.dyn.COBA.desc(E=0.) - elif out_model == 'mg': - out = bp.dyn.MgBlock.desc(E=0.) 
- else: - raise ValueError - - self.syn = bp.dyn.STDP_Song2000( - pre=self.pre, - delay=delay, - comm=comm, - syn=syn, - out=out, - post=self.post, - tau_s=16.8, - tau_t=33.7, - A1=0.96, - A2=0.53, - W_min=0., - W_max=1. - ) - - def update(self, I_pre, I_post): - self.syn() - self.pre(I_pre) - self.post(I_post) - conductance = self.syn.refs['syn'].g - Apre = self.syn.refs['pre_trace'].g - Apost = self.syn.refs['post_trace'].g - current = self.post.sum_current_inputs(self.post.V) - if comm_method == 'dense': - w = self.syn.comm.W.flatten() - else: - w = self.syn.comm.weight.flatten() - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w - - duration = 300. - I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) - I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) - - net = STDPNet(1, 1) - - def run(i, I_pre, I_post): - pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) - return pre_spike, post_spike, g, Apre, Apost, current, W - - indices = np.arange(int(duration / bm.dt)) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - - # import matplotlib.pyplot as plt - # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) - # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) - # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) - # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) - # plt.show() - - bm.clear_buffer_memory() - +# -*- coding: utf-8 -*- + +import numpy as np +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if 
import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +bm.set_platform('cpu') + + +class Test_STDP(parameterized.TestCase): + + @parameterized.product( + comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], + delay=[None, 0., 2.], + syn_model=['exp', 'dual_exp', 'ampa'], + out_model=['cuba', 'coba', 'mg'] + ) + def test_STDP(self, comm_method, delay, syn_model, out_model): + bm.random.seed() + + class STDPNet(bp.DynamicalSystem): + def __init__(self, num_pre, num_post): + super().__init__() + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + + if comm_method == 'all2all': + comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'csr': + if syn_model == 'exp': + comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + else: + comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'masked_linear': + comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'dense': + comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'one2one': + comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) + else: + raise ValueError + + if syn_model == 'exp': + syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) + elif syn_model == 'dual_exp': + syn = bp.dyn.DualExpon.desc(self.post.varshape) + elif syn_model == 'dual_exp_v2': + syn = bp.dyn.DualExponV2.desc(self.post.varshape) + elif syn_model == 'ampa': + syn = bp.dyn.AMPA.desc(self.post.varshape) + else: + raise ValueError + + if out_model == 'cuba': + out = bp.dyn.CUBA.desc() + elif out_model == 'coba': + out = bp.dyn.COBA.desc(E=0.) 
+ elif out_model == 'mg': + out = bp.dyn.MgBlock.desc(E=0.) + else: + raise ValueError + + self.syn = bp.dyn.STDP_Song2000( + pre=self.pre, + delay=delay, + comm=comm, + syn=syn, + out=out, + post=self.post, + tau_s=16.8, + tau_t=33.7, + A1=0.96, + A2=0.53, + W_min=0., + W_max=1. + ) + + def update(self, I_pre, I_post): + self.syn() + self.pre(I_pre) + self.post(I_post) + conductance = self.syn.refs['syn'].g + Apre = self.syn.refs['pre_trace'].g + Apost = self.syn.refs['post_trace'].g + current = self.post.sum_current_inputs(self.post.V) + if comm_method == 'dense': + w = self.syn.comm.W.flatten() + else: + w = self.syn.comm.weight.flatten() + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w + + duration = 300. + I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, + duration - 255]) + I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, + duration - 250]) + + net = STDPNet(1, 1) + + def run(i, I_pre, I_post): + pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) + return pre_spike, post_spike, g, Apre, Apost, current, W + + indices = np.arange(int(duration / bm.dt)) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + # import matplotlib.pyplot as plt + # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + # plt.show() + + bm.clear_buffer_memory() diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 90500a26f..eec2c9459 100644 --- 
a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -1,439 +1,444 @@ -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - -neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - - -def test_ProjAlignPreMg1(): - class EICOBA_PreAlign(bp.DynamicalSystem): - def __init__(self, scale=1., inp=20., delay=None): - super().__init__() - - self.inp = inp - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.I, - ) - self.E2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.E, - ) - self.I2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PreAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - 
net = EICOBA_PreAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPostMg2(): - class EICOBA_PostAlign(bp.DynamicalSystem): - def __init__(self, scale, inp=20., ltc=True, delay=None): - super().__init__() - self.inp = inp - - if ltc: - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - else: - self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2E = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E, - ) - self.E2I = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I, - ) - self.I2E = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PostAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - 
bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, ltc=False) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPost1(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.num_exc = int(3200 * scale) - self.num_inh = num - self.num_exc - prob = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:self.num_exc]) - self.I(spk[self.num_exc:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet(0.5) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPost2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale, delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (ne + ni) - - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = 
bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(0.5, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(0.5, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_VanillaProj(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=0.5): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) 
- self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg1_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - 
self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(scale=0.2, delay=None) - indices = bm.arange(400) - spks = 
bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(scale=0.2, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_vanalla_proj_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N - ) - self.I = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N - ) - - def update(self, input): - spk = self.delay.at('delay') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) - bp.visualize.raster_plot(indices, spks, show=True) - plt.close() - bm.clear_buffer_memory() - +import pytest +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + + +def 
test_ProjAlignPreMg1(): + class EICOBA_PreAlign(bp.DynamicalSystem): + def __init__(self, scale=1., inp=20., delay=None): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PreAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PreAlign(0.5, delay=1.) 
+ indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPostMg2(): + class EICOBA_PostAlign(bp.DynamicalSystem): + def __init__(self, scale, inp=20., ltc=True, delay=None): + super().__init__() + self.inp = inp + + if ltc: + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + else: + self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2E = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PostAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net 
= EICOBA_PostAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, ltc=False) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPost1(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.num_exc = int(3200 * scale) + self.num_inh = num - self.num_exc + prob = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:self.num_exc]) + self.I(spk[self.num_exc:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPost2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale, delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (ne + ni) + + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + 
V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(0.5, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(0.5, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_VanillaProj(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=0.5): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) 
+ self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg1_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + 
self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(scale=0.2, delay=None) + indices = bm.arange(400) + spks = 
bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(scale=0.2, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_vanalla_proj_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 1.)) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('delay') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) + bp.visualize.raster_plot(indices, spks, show=True) + plt.close() + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index badb60832..c3936f685 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,126 +1,130 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import 
brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py index 395868092..01a315261 100644 --- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py @@ -1,103 +1,108 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -biological_models = [ - bp.synapses.AMPA, - bp.synapses.GABAa, - bp.synapses.BioNMDA, -] - - -class Test_Biological_Synapse(parameterized.TestCase): - @parameterized.product( - synapse=biological_models, - delay_step=[None, 5, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_all2all_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_one2one_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - comp_method=['sparse', 'dense'], - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(10) - post_neu = bp.neurons.LIF(10) - syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), - comp_method=comp_method, delay_step=delay_step, - stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - 
monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 10) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +biological_models = [ + bp.synapses.AMPA, + bp.synapses.GABAa, + bp.synapses.BioNMDA, +] + + +class Test_Biological_Synapse(parameterized.TestCase): + @parameterized.product( + synapse=biological_models, + delay_step=[None, 5, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_all2all_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_one2one_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + comp_method=['sparse', 'dense'], + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(10) + post_neu = bp.neurons.LIF(10) + syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), + comp_method=comp_method, delay_step=delay_step, + stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + 
monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 10) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 19aca92cf..6ebe9dc26 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -7,7 +7,7 @@ __all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) # Default computation mode. mode = NonBatchingMode() @@ -24,15 +24,19 @@ # '''Default integer data type.''' int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default integer data type in Taichi.''' -ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type.''' float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default float data type in Taichi.''' -ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 - # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 +if ti is not None: + # '''Default integer data type in Taichi.''' + ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 + + # '''Default float data type in Taichi.''' + ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 + +else: + ti_int = None + ti_float = None diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 676e4286b..eb8e27c8f 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, 
is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import broadcast_to, expand_dims, concatenate +from .compat_numpy import vstack, broadcast_to from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,7 +392,6 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: - self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -473,7 +472,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) + self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: self.data[:] = value diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1c8b98a3b..668f837c0 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -18,7 +18,7 @@ from . 
import defaults from brainpy._src.dependency_check import import_taichi -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ # context manage for environment setting @@ -416,13 +416,16 @@ def set_float(dtype: type): """ if dtype in [jnp.float16, 'float16', 'f16']: defaults.__dict__['float_'] = jnp.float16 - defaults.__dict__['ti_float'] = ti.float16 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float16 elif dtype in [jnp.float32, 'float32', 'f32']: defaults.__dict__['float_'] = jnp.float32 - defaults.__dict__['ti_float'] = ti.float32 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float32 elif dtype in [jnp.float64, 'float64', 'f64']: defaults.__dict__['float_'] = jnp.float64 - defaults.__dict__['ti_float'] = ti.float64 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float64 else: raise NotImplementedError @@ -448,16 +451,20 @@ def set_int(dtype: type): """ if dtype in [jnp.int8, 'int8', 'i8']: defaults.__dict__['int_'] = jnp.int8 - defaults.__dict__['ti_int'] = ti.int8 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int8 elif dtype in [jnp.int16, 'int16', 'i16']: defaults.__dict__['int_'] = jnp.int16 - defaults.__dict__['ti_int'] = ti.int16 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int16 elif dtype in [jnp.int32, 'int32', 'i32']: defaults.__dict__['int_'] = jnp.int32 - defaults.__dict__['ti_int'] = ti.int32 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int32 elif dtype in [jnp.int64, 'int64', 'i64']: defaults.__dict__['int_'] = jnp.int64 - defaults.__dict__['ti_int'] = ti.int64 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int64 else: raise NotImplementedError diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 631129558..bdd3102a3 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,4 +1,2 @@ - -from ._info_collection import * from ._csr_matvec import * diff --git 
a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 6e03be463..6b7f7da02 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,34 +10,25 @@ """ -from functools import partial from typing import Union, Tuple import jax import jax.numpy as jnp -import numba import numpy as np -from jax.core import ShapedArray, Primitive -from jax.interpreters import ad, xla -from jax.lib import xla_client +from jax.interpreters import ad -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) -from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv +from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError __all__ = [ 'csrmv' ] -ti = import_taichi() - +ti = import_taichi(error_if_not_found=False) def csrmv( data: Union[float, jax.Array], @@ -53,577 +44,6 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. 
- If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. 
' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indices should be a 1D vector with int32 or int64 type.') - if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indptr should be a 1D vector with int32 or int64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # computing - return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) - - -# ---------------------------------------------------------- -# event csr matvec -# ---------------------------------------------------------- - -# operator for `event_csr_matvec` batching rule -# -------- - -def _batch_event_csr_matvec_abstract( - values, indices, indptr, events, *, batch_size, shape, transpose=False -): - return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0])) - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, _ = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % 
event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - values_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - value = values[values_bi, 0] - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += value - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += values[value_bi, j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, transpose = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - value = values[value_bi, 0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += value - res_val[bi, row_i] = r - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in numba.prange(shape[0]): - r = 0. 
- for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += values[value_bi, j] - res_val[bi, row_i] = r - - -def _batch_event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - inputs = (values, indices, indptr, events) - description = dict(batch_size=batch_size, shape=shape, transpose=transpose) - if transpose: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_transpose_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - else: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _batch_event_csr_matvec_gpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - pass - - -def _batch_event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return event_csr_matvec_batching_p.bind(values_dot, indices, indptr, events, - batch_size=batch_size, shape=shape, transpose=transpose) - - -def _batch_csr_matvec(values, indices, indptr, vectors, *, shape, transpose): - f = jax.vmap(partial(normal_csrmv, shape=shape, transpose=transpose), - in_axes=(0 if values.shape[0] > 1 else None, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if vectors.shape[0] > 1 else None)) - return f(values if values.shape[0] > 1 else values[0], - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - vectors if vectors.shape[0] > 1 else vectors[0]) - - -def 
_batch_event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return _batch_csr_matvec(values, indices, indptr, events_dot, - shape=shape, transpose=transpose) - - -def _batch_event_csr_matvec_transpose(ct, values, indices, indptr, events, *, - batch_size, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(events): - ct_events = ( - ad.Zero(events.aval) if type(ct) is ad.Zero else - _batch_csr_matvec(ct, indices, indptr, values, - shape=shape, transpose=not transpose) - ) - return values, indices, indptr, ct_events - else: - if values.aval.shape[1] == 1: # scalar - temp = event_csr_matvec_batching_p.bind(jnp.ones((1, 1)), indices, indptr, events, - batch_size=batch_size, shape=shape, - transpose=transpose) - ct_values = jax.vmap(jnp.inner)(ct, temp) - else: # heterogeneous values - if type(ct) is ad.Zero: - ct_values = ad.Zero(values.aval) - else: - - def _f(ct, indices, indptr, events, *, transpose): - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values - - f = jax.vmap(partial(_f, transpose=transpose), - in_axes=(0, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if events.shape[0] > 1 else None)) - ct_values = f(ct, - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - events if events.shape[0] > 1 else events[0]) - return ct_values, indices, indptr, events - - -event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching') -event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract) -event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = 
_batch_event_csr_matvec_cpu_translation -ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values, - None, None, _batch_event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose - - -# operator for `event_csr_matvec` # -# ------------------------------- # - - -def _event_csr_matvec_abstract(values, indices, indptr, events, *, shape, transpose=False): - return ShapedArray(dtype=values.dtype, shape=(shape[1] if transpose else shape[0],)) - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. 
- for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values - res_val[row_i] = r - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values - res_val[row_i] = r - - -def _event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, shape, transpose): - inputs = (values, indices, indptr, events) - event_type = c.get_shape(events) - description = dict(shape=shape, transpose=transpose) - if transpose: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_transpose_numba_imp1_bool - else: - imp = _event_csr_matvec_transpose_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_numba_imp1_bool - else: - imp = _event_csr_matvec_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - 
operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _event_csr_matvec_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_csr_matvec_p.name) - - # shape checking - data_shape = c.get_shape(data) - indices_shape = c.get_shape(indices) - indptr_shape = c.get_shape(indptr) - vec_shape = c.get_shape(vector) - if data_shape.element_type() == jnp.float32: - ftype = b'_float' - elif data_shape.element_type() == jnp.float64: - ftype = b'_double' - else: - raise ValueError - assert indices_shape.element_type() == indptr_shape.element_type() - if indices_shape.element_type() == jnp.int32: - itype = b'_int' - elif indices_shape.element_type() == jnp.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'_homo' if data_shape.dimensions() == (1,) else b'_heter' - tran_type = b'_transpose' if transpose else b'' - if vec_shape.element_type() == jnp.bool_: - vec_type = b'_bool' - else: - assert vec_shape.element_type() == data_shape.element_type() - vec_type = b'' - - # opaque - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - - # call - return xla_client.ops.CustomCallWithLayout( - c, - b'event_csrmv' + data_name + ftype + itype + vec_type + tran_type, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), - (shape[1] if transpose else shape[0],), - (0,)), - opaque=opaque, - ) - - -def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): - batch_size = 0 - args_processed = [] - for arg, axis in zip(args, axes): - if axis is None: - arg = jnp.expand_dims(jnp.atleast_1d(arg), 0) - else: - batch_size = arg.shape[axis] - if axis > 0: - arg = jnp.moveaxis(arg, axis, 0) - args_processed.append(arg) - - r = 
event_csr_matvec_batching_p.bind(*args_processed, - batch_size=batch_size, - shape=shape, - transpose=transpose) - return r, 0 - - -def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv(values, indices, indptr, ct, shape=shape, transpose=not transpose) - return values, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events) - else: - if type(ct) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) - ct_values = jnp.inner(ct, ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values, indices, indptr, events - - -event_csr_matvec_p = Primitive('event_csr_matvec') -event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract) -event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation -# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, - _event_csr_matvec_jvp_events_brainpylib) 
-ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib -register_general_batching(event_csr_matvec_p) - - -# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - Parameters ---------- data: ndarray, float @@ -691,298 +111,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. 
- - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - 
for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. 
- - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - def raw_csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -992,6 +120,9 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False ): + if ti is None: + raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') + if transpose: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -1025,65 +156,361 @@ def raw_csrmv_taichi( shape=shape) -def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) +if ti is not None: + + # ------------- + # CPU operators + # ------------- + + # 1. The benchmarking shows that the performance of the following transpose + # kernels is maximized when using serialized mode + # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable + # arguments, we have to define each kernel separately when the + # non-differentiable/non-jittable arguments are different. 
+ + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # 
ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + # 1. GPU kernels are different from the CPU ones, since the GPU kernels need + # to use warp-level parallelism to achieve the best performance. 
+ + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + # TODO + # It is important to note that the following warp-based kernels + # should be improved, since the atomic_add for each thread is not + # very efficient. Instead, the warp-level reduction primitive + # should be used. + # see ``warp_reduce_sum()`` function in tifunc.py. + # However, currently Taichi does not support general warp-level primitives. + + @ti.kernel + def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. 
+ j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. 
+ j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive -def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) + def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + + def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape + ): + if 
ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) + return prim -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) + # transpose bool homo + 
_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + # transpose homo + _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, + _event_csr_matvec_transpose_homo_gpu) -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + # not transpose bool homo + _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, + _event_csr_matvec_bool_homo_gpu) -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + # not transpose homo + _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, + _event_csr_matvec_homo_gpu) -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) + # transpose bool heter + _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) + # transpose heter + _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + # not transpose bool heter + _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, + _event_csr_matvec_bool_heter_gpu) -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + # not transpose heter + _event_csrmv_heter_p = 
_define_op(_event_csr_matvec_heter_cpu, + _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py deleted file mode 100644 index 7bb043e3e..000000000 --- a/brainpy/_src/math/event/_info_collection.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple, Union - -import jax -import numba -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client - -from brainpy._src.dependency_check import import_brainpylib_gpu_ops -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy.errors import GPUOperatorNotFound - -ti = import_taichi() - -__all__ = [ - 'info' -] - - -def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]: - """Collect event information, including event indices, and event number. - - This function supports JAX transformations, including `jit()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - events: jax.Array - The events. - - Returns - ------- - res: tuple - A tuple with two elements, denoting the event indices and the event number. 
- """ - events = as_jax(events) - if events.ndim != 1: - raise TypeError('Only support 1D boolean vector.') - return event_info_p(events) - - -def _batch_event_info_abstract(events): - assert events.ndim == 2 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(events.shape[0],)) - return event_ids, event_num - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -@ti.kernel -def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2), - event_ids: ti.types.ndarray(ndim=2), - event_num: ti.types.ndarray(ndim=1)): - for i, j in ti.grouped(ti.ndrange(event_ids.shape)): - event_ids[i, j] = -1 - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -def _batch_event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - shape = arg.shape - arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2])) - event_ids, event_num = batch_event_info_p(arg) - return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])), - (0, 0)) - - -def _event_info_gpu_translation(c, events): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_info_p.name) - - e_shape = c.get_shape(events).dimensions() - e_type = c.get_shape(events).element_type() - if len(e_shape) == 1: - event_size = e_shape[0] - batch_size = 1 - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (event_size,), - (0,)) - else: - 
batch_size, event_size = e_shape - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size, event_size), - (1, 0)) - event_num_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size,), - (0,)) - opaque = gpu_ops.build_nonzero_descriptor(event_size, batch_size) - - if e_type == jnp.bool_: - type_name = b'_bool' - elif e_type == jnp.int32: - type_name = b'_int' - elif e_type == jnp.int64: - type_name = b'_long' - elif e_type == jnp.float32: - type_name = b'_float' - elif e_type == jnp.float64: - type_name = b'_double' - else: - raise ValueError - - return xla_client.ops.CustomCallWithLayout( - c, - b'nonzero' + type_name, - operands=(events,), - operand_shapes_with_layout=(c.get_shape(events),), - shape_with_layout=xla_client.Shape.tuple_shape((event_ids_shape, event_num_shape)), - opaque=opaque, - ) - - -batch_event_info_p = XLACustomOp( - name='batched_event_info', - cpu_kernel=_batch_event_info_taichi, - gpu_kernel=_batch_event_info_taichi, - outs=_batch_event_info_abstract, -) -batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule) - - -def _event_info_abstract(events, **kwargs): - assert events.ndim == 1 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(1,)) - return event_ids, event_num - - -# TODO: first parallel evaluate the sub-sections, then serially event the sub-results. 
-@numba.njit(fastmath=True) -def _event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -@ti.kernel -def _event_info_taichi(events: ti.types.ndarray(ndim=1), - event_ids: ti.types.ndarray(ndim=1), - event_num: ti.types.ndarray(ndim=1)): - for i in range(event_ids.shape[0]): - event_ids[i] = -1 - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -def _event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - return (batch_event_info_p(arg), (0, 0)) - - -event_info_p = XLACustomOp( - name='event_info', - cpu_kernel=_event_info_taichi, - gpu_kernel=_event_info_taichi, - outs=_event_info_abstract, - # gpu_func_translation=_event_info_gpu_translation, -) -event_info_p.def_batching_rule(_event_info_batching_rule) diff --git a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py b/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py deleted file mode 100644 index 74cc6b7f9..000000000 --- a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py +++ /dev/null @@ -1,275 +0,0 @@ -from time import time - -from jax import jit, vmap, numpy as jnp - -import brainpy.math as bm - - -def compare_argsort_and_sum(platform='cpu'): - """ - CPU - --- - - shape = (100, 10000) - brainpylib 0.1872694492340088 s - JAX argsort + sum 5.297466516494751 s - - shape = (100, 100000) - brainpylib 2.333505153656006 s - JAX argsort + sum 65.20281910896301 s - - shape = (1000, 10000) - brainpylib 2.0739688873291016 s - JAX argsort + sum 53.70602822303772 s - - shape = (10000, 1000) - brainpylib 1.7262670993804932 s - JAX argsort + sum 43.92174816131592 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.14670848846435547 s - JAX argsort + sum 1.001936435699463 s - - shape = (100, 1000000) - brainpylib 
0.27660632133483887 s - JAX argsort + sum 16.390073776245117 s - - shape = (1000, 100000) - brainpylib 0.2619345188140869 s - JAX argsort + sum 9.715844869613647 s - - shape = (1000, 500000) - brainpylib 1.201209306716919 s - JAX argsort + sum 71.19761657714844 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: (jnp.argsort(events), jnp.sum(events)))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids2, event_num2 = jax_event_info(events) - assert jnp.allclose(event_num1, event_num2) - event_ids1.block_until_ready() - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, b = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort + sum {time() - t0} s') - - print() - - -def compare_argsort(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.19738531112670898 s - JAX argsort 5.301469087600708 s - - shape = (100, 100000) - brainpylib 2.3321938514709473 s - JAX argsort 65.13460850715637 s - - shape = (1000, 10000) - brainpylib 2.0956876277923584 s - JAX argsort 53.863110065460205 s - - shape = (10000, 1000) - brainpylib 1.7127799987792969 s - JAX argsort 44.05547475814819 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.1415419578552246 s - JAX argsort 0.9982438087463379 s - - shape = (100, 1000000) - brainpylib 0.3224947452545166 s - JAX argsort 16.504750967025757 s - - shape = (1000, 100000) - brainpylib 0.2781648635864258 s - JAX 
argsort 9.691488981246948 s - - shape = (1000, 500000) - brainpylib 1.2167487144470215 s - JAX argsort 71.68716263771057 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.argsort(events))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2 = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort {time() - t0} s') - - print() - - -def compare_where(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.20480966567993164 s - JAX where 0.7068588733673096 s - - shape = (100, 100000) - brainpylib 2.3373026847839355 s - JAX where 5.862265348434448 s - - shape = (1000, 10000) - brainpylib 2.105764865875244 s - JAX where 5.914586067199707 s - - shape = (10000, 1000) - brainpylib 1.724682331085205 s - JAX where 5.718563795089722 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.15492558479309082 s - JAX where 0.3146538734436035 s - - shape = (100, 1000000) - brainpylib 0.3290700912475586 s - JAX where 1.7064015865325928 s - - shape = (1000, 100000) - brainpylib 0.2895216941833496 s - JAX where 1.6910102367401123 s - - shape = (1000, 500000) - brainpylib 1.173649787902832 s - JAX where 7.868000268936157 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - 
jax_event_info = jit(vmap(lambda events: jnp.where(events, size=events.shape[0]))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2, = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX where {time() - t0} s') - - print() - - -if __name__ == '__main__': - # compare_argsort_and_sum('cpu') - # compare_argsort_and_sum('gpu') - # compare_argsort('cpu') - compare_argsort('gpu') - # compare_where('cpu') - # compare_where('gpu') diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index e0f38490f..67e09d0a4 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -4,11 +4,18 @@ from functools import partial import jax +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py deleted file mode 100644 index 31a6527a2..000000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_old.py +++ /dev/null @@ -1,324 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax -from absl.testing 
import parameterized - -import brainpy as bp -import brainpy.math as bm -import platform - -import pytest -pytest.skip('Old implementation.', allow_module_level=True) - -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') -taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_event_csr_matvec(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - for homo_data in [-1., 0., 1.] 
- ) - def test_homo(self, shape, transpose, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r3)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r4)) - - r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r5)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] 
- ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, - method='cusparse')) - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', - homo_data=homo_data, - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - 
(10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else - ((dense_conn * a) @ events))))(homo_data) - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else - ((dense_conn * homo_data) @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - ) - def test_heter(self, shape, transpose): - 
print(f'test_heter: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r3)) - - r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f"transpose={transpose}, shape={shape}", - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = 
bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'transpose={transpose},shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = 
jax.grad(sum_op(brainpylib_csr_matvec))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) - r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else - (a @ events))))(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r3 = r3[rows, cols] - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else - (dense_data @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_info.py b/brainpy/_src/math/event/tests/test_info.py deleted file mode 100644 index c326b0f76..000000000 --- a/brainpy/_src/math/event/tests/test_info.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -import unittest - -import brainpy.math as bm -from jax import vmap - -import pytest - - -class Test_event_info(unittest.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_info, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - def _base_test(self, length): - print(f'{self._base_test.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(length)) < 0.1 - event_ids, event_num = bm.event.info(events) - self.assertTrue(jnp.allclose(jnp.sum(events, keepdims=True), event_num)) - - 
bm.clear_buffer_memory() - - def _base_vmap(self, length): - print(f'{self._base_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(bm.event.info)(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap_vmap(self, length): - print(f'{self._base_vmap_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(vmap(bm.event.info))(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def test(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - - diff --git a/brainpy/_src/math/event/tests/test_info_gpu.py b/brainpy/_src/math/event/tests/test_info_gpu.py deleted file mode 100644 index 55bdd15cd..000000000 --- a/brainpy/_src/math/event/tests/test_info_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_info - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_info_GPU(test_info.Test_event_info): - def __init__(self, *args, **kwargs): - super(Test_event_info_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py deleted file mode 100644 index 6c71b4b06..000000000 --- a/brainpy/_src/math/index_tricks.py +++ /dev/null @@ -1,305 +0,0 @@ -# -*- coding: utf-8 -*- - -import abc - -from jax import core -from .compat_numpy import arange, array, concatenate, expand_dims, 
linspace, meshgrid, stack, transpose -import numpy as np - -__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] - - -def _make_1d_grid_from_slice(s: slice, op_name: str): - start = core.concrete_or_error(None, s.start, - f"slice start of jnp.{op_name}") or 0 - stop = core.concrete_or_error(None, s.stop, - f"slice stop of jnp.{op_name}") - step = core.concrete_or_error(None, s.step, - f"slice step of jnp.{op_name}") or 1 - if np.iscomplex(step): - newobj = linspace(start, stop, int(abs(step))) - else: - newobj = arange(start, stop, step) - - return newobj - - -class _IndexGrid(abc.ABC): - """Creates multi-dimensional grids of indices.""" - sparse: bool - op_name: str - - def __getitem__(self, key): - if isinstance(key, slice): - return _make_1d_grid_from_slice(key, op_name=self.op_name) - output = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) - output = meshgrid(*output, indexing='ij', sparse=self.sparse) - return output if self.sparse else stack(output, 0) - - -class _Mgrid(_IndexGrid): - """Return dense multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``. - - See Also: - jnp.ogrid: open/sparse version of jnp.mgrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> import brainpy.math as bm - >>> bm.mgrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.mgrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create broadcasted grids of indices: - - >>> bm.mgrid[:2, :3] - DeviceArray([[[0, 0, 0], - [1, 1, 1]], - [[0, 1, 2], - [0, 1, 2]]], dtype=int32) - """ - sparse = False - op_name = "mgrid" - - -mgrid = _Mgrid() - - -class _Ogrid(_IndexGrid): - """Return open multi-dimensional "meshgrid". 
- - LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``. - - See Also: - jnp.mgrid: dense version of jnp.ogrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> bm.ogrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.ogrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create sparse grids of indices: - - >>> bm.ogrid[:2, :3] - [DeviceArray([[0], - [1]], dtype=int32), - DeviceArray([[0, 1, 2]], dtype=int32)] - """ - sparse = True - op_name = "ogrid" - - -ogrid = _Ogrid() - - -class _AxisConcat(abc.ABC): - """Concatenates slices, scalars and array-like objects along a given axis.""" - axis: int - ndmin: int - trans1d: int - op_name: str - - def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - - params = [self.axis, self.ndmin, self.trans1d, -1] - - if isinstance(key[0], str): - # split off the directive - directive, *key = key # pytype: disable=bad-unpacking - # check two special cases: matrix directives - if directive == "r": - params[-1] = 0 - elif directive == "c": - params[-1] = 1 - else: - vec = directive.split(",") - k = len(vec) - if k < 4: - vec += params[k:] - else: - # ignore everything after the first three comma-separated ints - vec = vec[:3] + params[-1] - try: - params = list(map(int, vec)) - except ValueError as err: - raise ValueError( - "could not understand directive {!r}".format(directive) - ) from err - - axis, ndmin, trans1d, matrix = params - - output = [] - for item in key: - if isinstance(item, slice): - newobj = _make_1d_grid_from_slice(item, op_name=self.op_name) - elif isinstance(item, str): - raise ValueError("string directive must be placed at the beginning") - else: - newobj = item - - newobj = 
array(newobj, copy=False, ndmin=ndmin) - - if trans1d != -1 and ndmin - np.ndim(item) > 0: - shape_obj = list(range(ndmin)) - # Calculate number of left shifts, with overflow protection by mod - num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin - shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) - - newobj = transpose(newobj, shape_obj) - - output.append(newobj) - - res = concatenate(tuple(output), axis=axis) - - if matrix != -1 and res.ndim == 1: - # insert 2nd dim at axis 0 or 1 - res = expand_dims(res, matrix) - - return res - - def __len__(self): - return 0 - - -class RClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the first axis. - - LAX-backend implementation of :obj:`numpy.r_`. - - See Also: - ``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis. - - Examples: - Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects: - - >>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])] - DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32) - - An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, - which includes the right endpoint: - - >>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])] - DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005, - 0.6 , 1. , 0. , 1. , - 2. , 3. 
], dtype=float32) - - Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to - specify concatenation axis, minimum number of dimensions, and the position of the - upgraded array's original dimensions in the resulting array's shape tuple: - - >>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - >>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Negative values for ``trans1d`` offset the last axis towards the start - of the shape tuple: - - >>> bm.r_['0,2,-2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with an extra row or column axis, respectively: - - >>> bm.r_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32) - - >>> bm.r_['c',[1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"`` - give the same result. - """ - axis = 0 - ndmin = 1 - trans1d = -1 - op_name = "r_" - - -r_ = RClass() - - -class CClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the last axis. - - LAX-backend implementation of :obj:`numpy.c_`. - - See Also: - ``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis. 
- - Examples: - - >>> a = bm.arange(6).reshape((2,3)) - >>> bm.c_[a,a] - DeviceArray([[0, 1, 2, 0, 1, 2], - [3, 4, 5, 3, 4, 5]], dtype=int32) - - Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify - concatenation axis, minimum number of dimensions, and the position of the upgraded array's - original dimensions in the resulting array's shape tuple: - - >>> bm.c_['0,2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - >>> bm.c_['0,2,-1', [1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with inputs stacked along the last axis: - - >>> jnp.c_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 4], - [2, 5], - [3, 6]], dtype=int32) - """ - axis = -1 - ndmin = 2 - trans1d = 0 - op_name = "c_" - - -c_ = CClass() - -s_ = np.s_ - -index_exp = np.index_exp diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index a79cdc982..6f7cddf6a 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,3 +1,2 @@ - -from ._matvec import * -from ._event_matvec import * \ No newline at end of file +from ._matvec import * +from ._event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 3671755a9..976b72b96 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -1,21 +1,14 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from 
brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, - mv_prob_uniform_p, - mv_prob_normal_p, - mv_prob_homo, +from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, mv_prob_normal, _general_checking, @@ -27,11 +20,10 @@ _mv_prob_normal_transpose, _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'event_mv_prob_homo', @@ -50,746 +42,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return 
event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, 
int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] 
if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - 
xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# 
xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - 
return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - 
shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) 
-ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - 
event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - 
else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, 
event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. 
- outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) weight = jnp.atleast_1d(as_jax(weight)) @@ -799,11 +54,16 @@ def event_mv_prob_homo_taichi( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return raw_event_mv_prob_homo(events, weight, conn_len, seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ -def event_mv_prob_uniform_taichi( +def event_mv_prob_uniform( events: jax.Array, w_low: float, w_high: float, @@ -814,56 +74,9 @@ def event_mv_prob_uniform_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). 
- - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) if isinstance(w_high, float): w_high = as_jax(w_high) @@ -879,7 +92,10 @@ def event_mv_prob_uniform_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def event_mv_prob_normal_taichi( +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( events: jax.Array, w_mu: float, w_sigma: float, @@ -890,56 +106,9 @@ def event_mv_prob_normal_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. 
- - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) @@ -955,1034 +124,1036 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. 
-# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 +event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ + +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. 
+ + @ti.kernel + def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] = r + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. 
+ + @ti.kernel + def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = 
ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _reverse(shape): - return shape[::-1] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. 
- - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _reverse(shape): + return shape[::-1] + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. 
+ + @ti.kernel + def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + i_col += inc + out[i_row] += r # TODO: warp-level reduction -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return 
raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + + def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + 
cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu + ) -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu + ) -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, 
gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + 
clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + 
w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = 
events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - 
num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level 
reduction + + + def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = 
XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu + ) -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu + ) -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - if outdim_parallel: - if 
events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: 
ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - 
for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_normal_cpu( + 
events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] 
+ w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: + prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_normal_bool_p = 
_define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu + ) -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu + ) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = 
_define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 0caa9c996..00e5778f9 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -1,24 +1,20 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional, Union import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = 
import_taichi(error_if_not_found=False) __all__ = [ 'mv_prob_homo', @@ -85,8 +81,22 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] def mv_prob_uniform( @@ -150,8 +160,22 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. 
""" - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] def mv_prob_normal( @@ -215,1188 +239,110 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. 
""" - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + vector = as_jax(vector) + if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -### BRAINYPLIB ### -def mv_prob_homo_brainpylib( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, +def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. 
note:: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p + else: + prim = _mv_prob_homo_p - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - vector = as_jax(vector) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - )[0] - - -def mv_prob_uniform_brainpylib( +def raw_mv_prob_uniform( vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. 
+ return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_brainpylib( +def raw_mv_prob_normal( vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. 
transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - vector = as_jax(vector) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def _matvec_prob_homo_abstract( - vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - if transpose: - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({vector.shape[0]},) @ mat {shape}.') - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def 
_matvec_prob_homo_cpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - out_type = b'_float' - elif out_dtype == jnp.float64: - out_type = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_homo' + out_type - else: - fn = b'cpu_matvec_atomic_prob_homo' + out_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_homo_gpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_homo_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_homo_v2' + type_name - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - 
c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, weight, clen, seed = primals - vector_dot, weight_dot, clen_dot, seed_dot = tangents - r = mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(vector_dot) is ad.Zero: - raise ValueError - r_dot = mv_prob_homo_p.bind(vector_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(vector_dot) is ad.Zero: - r_dot = mv_prob_homo_p.bind(vector, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - r_dot = mv_prob_homo_p.bind(vector_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - return r, r_dot - - -def _matvec_prob_homo_transpose( - ct, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - assert type(vector) is ad.UndefinedPrimal - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -mv_prob_homo_p = Primitive('matvec_prob_homo') -mv_prob_homo_p.multiple_results = True -mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract) -mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation 
-# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation -register_general_batching(mv_prob_homo_p) -ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp -ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose - - -def _matvec_prob_uniform_abstract( - vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype == vector.dtype - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_uniform_cpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - 
out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_uniform' + type_name - else: - fn = b'cpu_matvec_atomic_prob_uniform' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_uniform_gpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_uniform_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_uniform_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - 
shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_low, w_high, clen, seed = primals - vector_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(vector_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -mv_prob_uniform_p = Primitive('matvec_prob_uniform') -mv_prob_uniform_p.multiple_results = True -mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract) -mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation -register_general_batching(mv_prob_uniform_p) -ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp -ad.primitive_transposes[mv_prob_uniform_p] = 
_matvec_prob_uniform_transpose - - -def _matvec_prob_normal_abstract( - vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_normal_cpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_normal' + type_name - else: - fn = 
b'cpu_matvec_atomic_prob_normal' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_normal_gpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - event_shape = c.get_shape(vector) - out_dtype = event_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_normal_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_normal_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed,), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_mu, w_sigma, clen, 
seed = primals - vector_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(vector_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -mv_prob_normal_p = Primitive('matvec_prob_normal') -mv_prob_normal_p.multiple_results = True -mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract) -mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation -register_general_batching(mv_prob_normal_p) -ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp -ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose - - -### TAICHI ### -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` 
operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. 
- shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. 
note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. 
- """ - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _reverse(shape): - return shape[::-1] - - -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' 
- - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): if vector.ndim != 1: raise ValueError('vector should be a 1D vector.') if len(shape) != 2: @@ -1437,190 +383,28 @@ def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, * return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = 
out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - 
i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) +def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + if ad.is_undefined_primal(vector): + if 
type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed + else: + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' def _mv_prob_uniform_transpose( @@ -1641,265 +425,496 @@ def _mv_prob_uniform_transpose( assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p +def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed else: - prim = _mv_prob_uniform_p + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' 
+ assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) +def _reverse(shape): + return shape[::-1] -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v + +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + + @ti.kernel + def _mv_prob_homo_cpu( + vector: 
ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + + @ti.kernel + def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: 
ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] = r + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - 
r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v + def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + + # outdim_parallel = False + _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + + @ti.kernel + def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + 
seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = 
_reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True + _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu + ) - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p + # outdim_parallel = False + 
_mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu + ) - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) + @ti.kernel + def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. 
+ key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. 
+ key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) + + def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, 
gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu + ) + + # outdim_parallel = False + _mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index b10d55d21..d8e086540 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -4,8 +4,14 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 2e6e406cf..8a0ae444d 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,8 +4,13 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/object_transform/autograd.py 
b/brainpy/_src/math/object_transform/autograd.py index ad8a5ccf6..f5e091675 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,8 +28,10 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, VariableStack) -from .tools import eval_shape +from .variables import (Variable, + VariableStack, + current_transform_number, + new_transform) __all__ = [ 'grad', # gradient of scalar function @@ -201,21 +203,36 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) + with new_transform(self): + with VariableStack() as stack: + if current_transform_number() > 1: + rets = self._transform( + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) + else: + rets = jax.eval_shape( + self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if not stack.is_first_stack(): - return self._return(rets) + # if not the outermost transformation + if current_transform_number(): + return self._return(rets) + else: + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git 
a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index c52845a06..aaf053ae7 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,6 +6,7 @@ """ import numbers +import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -13,13 +14,14 @@ import jax import numpy as np -from brainpy._src.math.modes import Mode +from brainpy import errors from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) +from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3edeb08e8..032a0fab6 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,12 +21,17 @@ cache_stack ) from .tools import ( - eval_shape, + evaluate_dyn_vars, dynvar_deprecation, node_deprecation, abstract ) -from .variables import (Variable, VariableStack) +from .variables import ( + Variable, + VariableStack, + new_transform, + current_transform_number, +) __all__ = [ 'make_loop', @@ -537,13 +542,15 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit and dyn_vars is None: - with VariableStack() as dyn_vars: - rets = eval_shape(true_fun, *operands, with_stack=True)[1] - _ = eval_shape(false_fun, *operands, with_stack=True) - cache_stack((true_fun, false_fun), dyn_vars) - if not dyn_vars.is_first_stack(): - return rets + if not jax.config.jax_disable_jit: + if dyn_vars is 
None: + with new_transform('cond'): + dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((true_fun, false_fun), dyn_vars) + if current_transform_number() > 0: + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -674,16 +681,20 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with VariableStack() as dyn_vars: - rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with new_transform('ifelse'): + with VariableStack() as dyn_vars: + if current_transform_number() > 1: + rets = [branch(*operands) for branch in branches] + else: + rets = [jax.eval_shape(branch, *operands) for branch in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. 
But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if not dyn_vars.is_first_stack(): + if current_transform_number(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -869,23 +880,28 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - stack = get_stack_cache((body_fun, unroll_kwargs)) + dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if stack is None: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, - remat, reverse, unroll, unroll_kwargs) + if dyn_vars is None: # TODO: better cache mechanism? - with VariableStack() as stack: - rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), stack) # cache - if not stack.is_first_stack(): + with new_transform('for_loop'): + with VariableStack() as dyn_vars: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, + progress_bar, remat, reverse, unroll, + unroll_kwargs) + if current_transform_number() > 1: + rets = transform(operands) + else: + rets = jax.eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache + if current_transform_number(): return rets[1] del rets else: - stack = VariableStack() + dyn_vars = VariableStack() # TODO: cache mechanism? 
- transform = _get_for_loop_transform(body_fun, stack, bar, + transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -893,11 +909,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, stack + del dyn_vals, dyn_vars return out_vals @@ -995,21 +1011,26 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - stack = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit and stack is None: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as stack: - rets = eval_shape(transform, init, operands) - cache_stack(body_fun, stack) # cache - if not stack.is_first_stack(): - return rets[0][1], rets[1] - del rets - - stack = VariableStack() if stack is None else stack - transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) + dyn_vars = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('scan'): + with VariableStack() as dyn_vars: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + if current_transform_number() > 1: + rets = transform(init, operands) + else: + rets = jax.eval_shape(transform, init, operands) + cache_stack(body_fun, dyn_vars) # cache + if current_transform_number(): + return rets[0][1], rets[1] + del rets + + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key 
in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1108,6 +1129,7 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. + """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1115,16 +1137,18 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - stack = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit and stack is None: - with VariableStack() as stack: - _ = eval_shape(cond_fun, *operands, with_stack=True) - rets = eval_shape(body_fun, *operands, with_stack=True)[1] - cache_stack((body_fun, cond_fun), stack) - if not stack.is_first_stack(): - return rets - stack = VariableStack() if stack is None else stack - dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) - for k, v in stack.items(): + dyn_vars = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('while_loop'): + dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((body_fun, cond_fun), dyn_vars) + if current_transform_number(): + return rets + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) + for k, v in dyn_vars.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 73eab2f91..7bb36f4e2 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,15 +11,23 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import 
jax +from jax.sharding import Sharding from brainpy import tools, check -from .base import BrainPyObject, ObjectTransform -from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - eval_shape) -from .variables import (Variable, VariableStack) + evaluate_dyn_vars_with_cache, + evaluate_dyn_vars, + _partial_fun) +from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from ..ndarray import Array +from .variables import (Variable, + VariableStack, + outermost_transform, + transform_stack, + current_transform_number, + new_transform) RandomState = None @@ -143,12 +151,16 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with VariableStack() as self._dyn_vars: - rets = eval_shape(self.fun, - *args, - **kwargs, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames) + with new_transform(self): + self._dyn_vars, rets = evaluate_dyn_vars( + self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + use_eval_shape=current_transform_number() <= 1, + **kwargs + ) + # in_shardings if self._in_shardings is None: in_shardings = None @@ -174,18 +186,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, 
self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -195,7 +207,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if not self._dyn_vars.is_first_stack(): + if current_transform_number(): return rets # call the transformed function @@ -465,8 +477,15 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - with VariableStack() as stack: - _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + + with jax.ensure_compile_time_eval(): + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 + with VariableStack() as stack: + _ = jax.eval_shape(fun3, *args_, **kwargs_) + del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1181e003b..1c8ca6ef9 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=True): +def clear_name_cache(ignore_warn=False): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,7 +57,6 @@ def cache_stack(func, stack): def clear_stack_cache(): - """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del 
_fun2stack[k] diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py new file mode 100644 index 000000000..1eddce048 --- /dev/null +++ b/brainpy/_src/math/object_transform/parallels.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- + +""" +The parallel compilation tools for JAX backend. + +1. Vectorize compilation is implemented by the 'vmap()' function +2. Parallel compilation is implemented by the 'pmap()' function + +""" + + +import functools + +import jax +import jax.numpy as jnp +import numpy as np +from jax.interpreters.partial_eval import DynamicJaxprTracer +from jax.interpreters.partial_eval import JaxprTracer +from jax.interpreters.pxla import ShardedDeviceArray + +try: + from jax.errors import UnexpectedTracerError +except ImportError: + from jax.core import UnexpectedTracerError + +from brainpy import errors +from brainpy._src.math.random import RandomState +from brainpy._src.math.ndarray import Array +from brainpy.tools import change_func_name +from .base import BrainPyObject, ArrayCollector + +__all__ = [ + 'vmap', + 'pmap', +] + + +def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, + batch_idx, axis_name, f_name=None): + @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + out = func(*args, **kwargs) + nonbatched_changes = nonbatched_vars.dict() + batched_changes = batched_vars.dict() + return nonbatched_changes, batched_changes, out + + def call(*args, **kwargs): + n = args[batch_idx[0]].shape[batch_idx[1]] + nonbatched_data = nonbatched_vars.dict() + batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} + try: + out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) + except UnexpectedTracerError as e: + 
nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + raise errors.JaxTracerError() from e + # for key, v in dyn_changes.items(): + # dyn_vars[key] = reduce_func(v) + # for key, v in rand_changes.items(): + # rand_vars[key] = reduce_func(v) + return out + + return change_func_name(name=f_name, f=call) if f_name else call + + +def vmap(func, dyn_vars=None, batched_vars=None, + in_axes=0, out_axes=0, axis_name=None, + reduce_func=None, auto_infer=False): + """Vectorization compilation for class objects. + + Vectorized compile a function or a module to run in parallel on a single device. + + Examples + -------- + + Parameters + ---------- + func : BrainPyObject, function, callable + The function or the module to compile. + dyn_vars : dict, sequence + batched_vars : dict + in_axes : optional, int, sequence of int + Specify which input array axes to map over. If each positional argument to + ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, + or a tuple of integers and Nones with length equal to the number of + positional arguments to ``obj_or_func``. An integer or ``None`` + indicates which array axis to map over for all arguments (with ``None`` + indicating not to map any axis), and a tuple indicates which axis to map + for each corresponding positional argument. Axis integers must be in the + range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of + dimensions (axes) of the corresponding input array. + + If the positional arguments to ``obj_or_func`` are container types, the + corresponding element of ``in_axes`` can itself be a matching container, + so that distinct array axes can be mapped for different container + elements. ``in_axes`` must be a container tree prefix of the positional + argument tuple passed to ``obj_or_func``. + + At least one positional argument must have ``in_axes`` not None. The sizes + of the mapped input axes for all mapped positional arguments must all be + equal. 
+ + Arguments passed as keywords are always mapped over their leading axis + (i.e. axis index 0). + out_axes : optional, int, tuple/list/dict + Indicate where the mapped axis should appear in the output. All outputs + with a mapped axis must have a non-None ``out_axes`` specification. Axis + integers must be in the range ``[-ndim, ndim)`` for each output array, + where ``ndim`` is the number of dimensions (axes) of the array returned + by the :func:`vmap`-ed function, which is one more than the number of + dimensions (axes) of the corresponding array returned by ``obj_or_func``. + axis_name : optional + + Returns + ------- + obj_or_func : Any + Batched/vectorized version of ``obj_or_func`` with arguments that correspond to + those of ``obj_or_func``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but + with extra array axes at positions indicated by ``out_axes``. + + """ + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # dyn_vars = (dyn_vars or func.vars().unique()) + # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() + # for key, val in dyn_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # in axes + # if in_axes is None: + # in_axes = {key: (None, 0) for key in func.steps.keys()} + # elif isinstance(in_axes, int): + # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, (tuple, list)): + # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in in_axes: + # in_axes = {key: (None, 0, in_axes) for key in keys} + # else: + # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} + # assert isinstance(in_axes, dict) + # + # # batch size index + # batch_idx = {} + # 
for key, axes in in_axes.items(): + # for i, axis in enumerate(axes[2:]): + # if axis is not None: + # batch_idx[key] = (i, axis) + # break + # else: + # raise ValueError(f'Found no batch axis: {axes}.') + # + # # out axes + # if out_axes is None: + # out_axes = {key: 0 for key in func.steps.keys()} + # elif isinstance(out_axes, int): + # out_axes = {key: out_axes for key in func.steps.keys()} + # elif isinstance(out_axes, (tuple, list)): + # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} + # elif isinstance(out_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in out_axes: + # out_axes = {key: (out_axes, 0, 0) for key in keys} + # else: + # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} + # assert isinstance(out_axes, dict) + # + # # reduce_func + # if reduce_func is None: + # reduce_func = lambda x: x.mean(axis=0) + # + # # vectorized map functions + # for key in func.steps.keys(): + # func.steps[key] = _make_vmap(func=func.steps[key], + # dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # in_axes=in_axes[key], + # out_axes=out_axes[key], + # axis_name=axis_name, + # batch_idx=batch_idx[key], + # reduce_func=reduce_func, + # f_name=key) + # + # return func + + if callable(func): + if auto_infer: + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.vmap(func, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name) + + else: + if isinstance(dyn_vars, Array): + dyn_vars = [dyn_vars] + if isinstance(dyn_vars, (tuple, list)): + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} + assert isinstance(dyn_vars, dict) + + # dynamical variables + _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() + for key, val in 
dyn_vars.items(): + if isinstance(val, RandomState): + _rand_vars[key] = val + else: + _dyn_vars[key] = val + + # in axes + if in_axes is None: + in_axes = (None, 0) + elif isinstance(in_axes, (int, dict)): + in_axes = (None, 0, in_axes) + elif isinstance(in_axes, (tuple, list)): + in_axes = (None, 0) + tuple(in_axes) + assert isinstance(in_axes, (tuple, list)) + + # batch size index + batch_idx = {} + for key, axes in batch_idx.items(): + for i, axis in enumerate(axes[2:]): + if axis is not None: + batch_idx[key] = (i, axis) + break + else: + raise ValueError(f'Found no batch axis: {axes}.') + + # out axes + if out_axes is None: + out_axes = 0 + elif isinstance(out_axes, (int, dict)): + out_axes = (out_axes, 0, 0) + elif isinstance(out_axes, (tuple, list)): + out_axes = tuple(out_axes) + (0, 0) + assert isinstance(out_axes, (list, tuple)) + + # reduce_func + if reduce_func is None: + reduce_func = lambda x: x.mean(axis=0) + + # jit function + return _make_vmap(func=func, + nonbatched_vars=_dyn_vars, + batched_vars=_rand_vars, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + batch_idx=batch_idx) + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' + f'function, but we got {type(func)}.') + + +def _device_reshape(x): + """Reshape an input array in order to broadcast to multiple devices.""" + num_device = jax.local_device_count() + + if not hasattr(x, 'ndim'): + raise errors.BrainPyError(f'Expected Array, got {type(x)}. 
If you are trying to pass a scalar to ' + f'parallel, first convert it to a Array, for example np.float(0.5)') + if x.ndim == 0: + return np.broadcast_to(x, [num_device]) + if x.shape[0] % num_device != 0: + raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' + f'{num_device} devices, but does not go equally.') + return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) + + +def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, + out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): + @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, + backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + def pmapped_func(dyn_data, rand_data, *args, **kwargs): + dyn_vars.assign(dyn_data) + rand_vars.assign(rand_data) + out = func(*args, **kwargs) + dyn_changes = dyn_vars.dict() + rand_changes = rand_vars.dict() + return out, dyn_changes, rand_changes + + def call(*args): + un_replicated = [k for k, v in dyn_vars.items() + if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] + if len(un_replicated): + raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' 
+ f'did you forget to call xx.replicate() on them?') + _args = [] + for i, x in enumerate(args): + if i + 2 in static_broadcasted_argnums: + _args.append(x) + else: + _args.append(jax.tree_map(_device_reshape, [x])[0]) + dyn_data = dyn_vars.dict() + rand_data = rand_vars.dict() + output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) + dyn_vars.assign(dyn_changes) + rand_vars.assign(rand_changes) + return jax.tree_map(reduce_func, output) + + return change_func_name(name=f_name, f=call) if f_name else call + + +def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), + devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, + reduce_func=None): + """Parallel compilation for class objects. + + Parallel compile a function or a module to run on multiple devices in parallel. + + Parameters + ---------- + func + axis_name + in_axes + out_axes + static_broadcasted_argnums + devices + backend + axis_size + donate_argnums + global_arg_shapes + + Returns + ------- + + + Examples + -------- + + + """ + + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # all_vars = (dyn_vars or func.vars().unique()) + # dyn_vars = ArrayCollector() + # rand_vars = ArrayCollector() + # for key, val in all_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # reduce function + # if reduce_func is None: + # reduce_func = jnp.concatenate + # + # # static broadcast-ed arguments + # if static_broadcasted_argnums is None: + # static_broadcasted_argnums = () + # elif isinstance(static_broadcasted_argnums, int): + # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + # elif isinstance(static_broadcasted_argnums, (tuple, list)): + # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + # assert 
isinstance(static_broadcasted_argnums, (tuple, list)) + # + # # jit functions + # for key in func.steps.keys(): + # step = func.steps[key] + # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # func=step, + # axis_name=axis_name, + # in_axes=in_axes, + # out_axes=out_axes, + # static_broadcasted_argnums=static_broadcasted_argnums, + # devices=devices, + # backend=backend, + # axis_size=axis_size, + # donate_argnums=donate_argnums, + # global_arg_shapes=global_arg_shapes, + # reduce_func=reduce_func, + # f_name=key) + # return func + + if callable(func): + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.pmap(func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + else: + # dynamical variables + dyn_vars = ArrayCollector() + rand_vars = ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + rand_vars[key] = val + else: + dyn_vars[key] = val + + # static broadcast-ed arguments + if static_broadcasted_argnums is None: + static_broadcasted_argnums = () + elif isinstance(static_broadcasted_argnums, int): + static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + elif isinstance(static_broadcasted_argnums, (tuple, list)): + static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + assert isinstance(static_broadcasted_argnums, (tuple, list)) + + # reduce function + if reduce_func is None: + reduce_func = jnp.concatenate + + # jit function + func.__call__ = _make_pmap(dyn_vars=dyn_vars, + 
rand_vars=rand_vars, + func=func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + reduce_func=reduce_func) + return func + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' + f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 632c6d79e..7b519590a 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,65 +132,19 @@ def evaluate_dyn_vars_with_cache( return stack -def _partial_fun2( - fun: Callable, - args: tuple, - kwargs: dict, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = () -): - num_args = len(args) - - # arguments - static_args = dict() - dyn_args = [] - dyn_arg_ids = dict() - static_argnums = list(static_argnums) - dyn_i = 0 - for i in range(num_args): - if i in static_argnums: - static_argnums.remove(i) - static_args[i] = args[i] - else: - dyn_args.append(args[i]) - dyn_arg_ids[i] = dyn_i - dyn_i += 1 - if len(static_argnums) > 0: - raise ValueError(f"Invalid static_argnums: {static_argnums}") - - # keyword arguments - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames - - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], - **static_kwargs, - **dynkwargs) - - return new_fun, dyn_args, dyn_kwargs - - def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - with_stack: bool = False, **kwargs ): """Compute the 
shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - with_stack: Whether evaluate the function within a local variable stack. + *args: + **kwargs: static_argnums: The static argument indices. static_argnames: The static argument names. @@ -199,30 +153,21 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + f2, args, kwargs = _partial_fun(fun, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: - f2 = fun + f2, args, kwargs = fun, args, kwargs # evaluate the function fun_in_eval_shape.append(fun) try: - if with_stack: + with jax.ensure_compile_time_eval(): with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + returns = fun(*args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) - else: - stack = None - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) + returns = jax.eval_shape(fun, *args, **kwargs) finally: fun_in_eval_shape.pop() - del f2 - if with_stack: - return stack, returns - else: - return returns - + return stack, returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index b7babae8d..5014da0bf 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -189,14 +190,6 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id - @classmethod - def num_of_stack(self): - return len(var_stack_list) - - @classmethod - def is_first_stack(self): - return 
len(var_stack_list) == 0 - def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -217,6 +210,42 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] +transform_stack: List[Callable] = [] + + +@contextmanager +def new_transform(transform: Any): + transform_stack.append(transform) + try: + yield + finally: + transform_stack.pop() + + +def outermost_stack(): + if len(var_stack_list): + return var_stack_list[0] + else: + return None + + +def outermost_transform(): + if len(transform_stack): + return transform_stack[0] + else: + return None + + +def current_transform_number(): + return len(transform_stack) + + +def _stack_add_read(var: 'Variable'): + pass + + +def _stack_add_write(var: 'Variable'): + pass @register_pytree_node_class diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 01f77dbca..ed687eea5 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,7 +1,8 @@ - -from .numba_approach import (CustomOpByNumba, - register_op_with_numba, - compile_cpu_signature_with_numba) -from .taichi_aot_based import clean_caches, check_kernels_count -from .base import XLACustomOp -from .utils import register_general_batching +from .numba_approach import (CustomOpByNumba, + register_op_with_numba, + compile_cpu_signature_with_numba) +from .base import XLACustomOp +from .utils import register_general_batching +from .taichi_aot_based import clean_caches, check_kernels_count +from .base import XLACustomOp +from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 1824ac911..ca070a197 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -4,8 +4,8 @@ import jax import numpy as np from jax.interpreters import xla, batching, ad, mlir -from 
numba.core.dispatcher import Dispatcher +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -20,6 +20,8 @@ from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp +numba = import_numba(error_if_not_found=False) + __all__ = [ 'XLACustomOp', ] @@ -104,24 +106,30 @@ def __init__( self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) # cpu function + cpu_checked = False if cpu_kernel is None: - pass - elif isinstance(cpu_kernel, Dispatcher): # numba - register_numba_cpu_translation_rule(self.primitive, cpu_kernel) - elif hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi + cpu_checked = True + if numba is not None: # numba + from numba.core.dispatcher import Dispatcher + if isinstance(cpu_kernel, Dispatcher): + register_numba_cpu_translation_rule(self.primitive, cpu_kernel) + cpu_checked = True + if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) - else: + cpu_checked = True + if not cpu_checked: raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' f'But we got {cpu_kernel}') # gpu function + gpu_checked = False if gpu_kernel is None: - pass - elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + gpu_checked = True + if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) - else: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. ' - f'But we got {gpu_kernel}') + gpu_checked = True + if not gpu_checked: + raise ValueError(f'"cpu_kernel" must be a taichi kernel function. 
But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index cc2ce5b4c..5bbd04e0c 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -1,20 +1,22 @@ # -*- coding: utf-8 -*- -import warnings from functools import partial from typing import Callable from typing import Union, Sequence -import numba import jax from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from numba.core.dispatcher import Dispatcher +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy.errors import PackageMissingError from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba +numba = import_numba(error_if_not_found=False) + + __all__ = [ 'CustomOpByNumba', 'register_op_with_numba', @@ -137,6 +139,9 @@ def register_op_with_numba( f'For more information, please refer to the documentation: ' f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + if numba is None: + raise PackageMissingError.by_purpose('numba', 'custom op with numba') + if out_shapes is None: raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' 'a sequence of `ShapedArray`. 
If it is a function, it takes as input the argument ' @@ -146,6 +151,7 @@ def register_op_with_numba( prim.multiple_results = multiple_results # user defined function + from numba.core.dispatcher import Dispatcher if not isinstance(cpu_func, Dispatcher): cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) @@ -196,5 +202,3 @@ def abs_eval_rule(*input_shapes, **info): ad.primitive_transposes[prim] = transpose_translation return prim - - diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 13974b5b2..4b06effdf 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -1,146 +1,152 @@ -# -*- coding: utf-8 -*- - -import ctypes - -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client -from numba import types, carray, cfunc - -__all__ = [ - 'compile_cpu_signature_with_numba' -] - -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - - -def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): - target_name, inputs, input_shapes, xla_output_shapes = \ - compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_shapes, - shape_with_layout=xla_output_shapes, - ) - - -def _cpu_signature( - func, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - multiple_results: bool, - debug: bool = False -): - code_scope = dict( - func_to_call=func, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) 
- - # inputs - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - - # outputs - if multiple_results: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - - # function body - code_string = ''' -def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - new_f = code_scope['xla_cpu_custom_call_target'] - if multiple_results: - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) - else: - xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) - target_name = xla_c_rule.native_name.encode("ascii") - capsule = ctypes.pythonapi.PyCapsule_New( - xla_c_rule.address, # A CFFI pointer to a function - b"xla._CUSTOM_CALL_TARGET", # A binary string - None # PyCapsule object run at destruction - ) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - return target_name - - -def compile_cpu_signature_with_numba( - c, - func, - abs_eval_fn, - multiple_results, - inputs: tuple, - description: dict = None, -): - input_layouts = [c.get_shape(arg) for arg in inputs] - info_inputs = [] - if description is None: description = dict() - for v in description.values(): - if isinstance(v, (int, float)): - input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - 
info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) - elif isinstance(v, (tuple, list)): - v = jnp.asarray(v) - input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) - info_inputs.append(xla_client.ops.Constant(c, v)) - else: - raise TypeError - input_layouts = tuple(input_layouts) - input_dtypes = tuple(shape.element_type() for shape in input_layouts) - input_dimensions = tuple(shape.dimensions() for shape in input_layouts) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_layouts[:len(inputs)]), - **description) - if isinstance(output_abstract_arrays, ShapedArray): - output_abstract_arrays = (output_abstract_arrays,) - assert not multiple_results - else: - assert multiple_results - output_shapes = tuple(array.shape for array in output_abstract_arrays) - output_dtypes = tuple(array.dtype for array in output_abstract_arrays) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - target_name = _cpu_signature(func, - input_dtypes, - input_dimensions, - output_dtypes, - output_shapes, - multiple_results, - debug=False) - output_layouts = [xla_client.Shape.array_shape(*arg) - for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_layouts = (xla_client.Shape.tuple_shape(output_layouts) - if multiple_results else - output_layouts[0]) - return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts +# -*- coding: utf-8 -*- + +import ctypes + +from jax import dtypes, numpy as jnp +from jax.core import ShapedArray +from jax.lib import xla_client + +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = 
ctypes.py_object + +__all__ = [ + '_cpu_translation', + 'compile_cpu_signature_with_numba', +] + +if numba is not None: + from numba import types, carray, cfunc + + +def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): + target_name, inputs, input_shapes, xla_output_shapes = \ + compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) + return xla_client.ops.CustomCallWithLayout( + c, + target_name, + operands=inputs, + operand_shapes_with_layout=input_shapes, + shape_with_layout=xla_output_shapes, + ) + + +def _cpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + multiple_results: bool, + debug: bool = False +): + code_scope = dict( + func_to_call=func, + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + carray=carray, + ) + + # inputs + if len(input_shapes) > 1: + args_in = [ + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' + for i in range(len(input_shapes)) + ] + args_in = '(\n ' + "\n ".join(args_in) + '\n )' + else: + args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' + + # outputs + if multiple_results: + args_out = [ + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' + for i in range(len(output_shapes)) + ] + args_out = '(\n ' + "\n ".join(args_out) + '\n )' + else: + args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' + + # function body + code_string = ''' +def xla_cpu_custom_call_target(output_ptrs, input_ptrs): + args_out = {args_out} + args_in = {args_in} + func_to_call(args_out, args_in) + '''.format(args_in=args_in, + args_out=args_out) + if debug: print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + + new_f = code_scope['xla_cpu_custom_call_target'] + if multiple_results: + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + 
types.CPointer(types.voidptr)))(new_f) + else: + xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) + target_name = xla_c_rule.native_name.encode("ascii") + capsule = ctypes.pythonapi.PyCapsule_New( + xla_c_rule.address, # A CFFI pointer to a function + b"xla._CUSTOM_CALL_TARGET", # A binary string + None # PyCapsule object run at destruction + ) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + return target_name + + +def compile_cpu_signature_with_numba( + c, + func, + abs_eval_fn, + multiple_results, + inputs: tuple, + description: dict = None, +): + input_layouts = [c.get_shape(arg) for arg in inputs] + info_inputs = [] + if description is None: description = dict() + for v in description.values(): + if isinstance(v, (int, float)): + input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) + elif isinstance(v, (tuple, list)): + v = jnp.asarray(v) + input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) + info_inputs.append(xla_client.ops.Constant(c, v)) + else: + raise TypeError + input_layouts = tuple(input_layouts) + input_dtypes = tuple(shape.element_type() for shape in input_layouts) + input_dimensions = tuple(shape.dimensions() for shape in input_layouts) + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) + for shape in input_layouts[:len(inputs)]), + **description) + if isinstance(output_abstract_arrays, ShapedArray): + output_abstract_arrays = (output_abstract_arrays,) + assert not multiple_results + else: + assert multiple_results + output_shapes = tuple(array.shape for array in output_abstract_arrays) + output_dtypes = tuple(array.dtype for array in output_abstract_arrays) + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) + target_name = _cpu_signature(func, + 
input_dtypes, + input_dimensions, + output_dtypes, + output_shapes, + multiple_results, + debug=False) + output_layouts = [xla_client.Shape.array_shape(*arg) + for arg in zip(output_dtypes, output_shapes, output_layouts)] + output_layouts = (xla_client.Shape.tuple_shape(output_layouts) + if multiple_results else + output_layouts[0]) + return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fb76aed24..f461f4277 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -6,17 +6,20 @@ from jax.interpreters import xla, mlir from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -from numba import types, carray, cfunc +from brainpy._src.dependency_check import import_numba +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout +numba = import_numba(error_if_not_found=False) +if numba is not None: + from numba import types, carray, cfunc __all__ = [ 'register_numba_xla_cpu_translation_rule', 'register_numba_mlir_cpu_translation_rule', ] - # [void* pointer, # const char *name, # PyCapsule_Destructor destructor] @@ -104,6 +107,9 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, @@ -168,5 +174,8 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs): def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + rule = 
partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py index 24f010a12..2c9f09724 100644 --- a/brainpy/_src/math/op_register/tests/test_ad_support.py +++ b/brainpy/_src/math/op_register/tests/test_ad_support.py @@ -1,13 +1,18 @@ +import pytest from typing import Tuple import jax -import numba from jax import core from jax import numpy as jnp from jax.interpreters import ad import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index 968155ef9..dc093f624 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,6 +1,11 @@ +import pytest import jax.core import brainpy.math as bm -import numba + +from brainpy._src.dependency_check import import_numba +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 03023754c..4db38fbcb 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,9 +1,14 @@ +import pytest import jax import jax.numpy as jnp -import taichi as ti import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + bm.set_platform('cpu') diff --git 
a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 1bebcdafe..51c964b29 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,54 +1,58 @@ -import brainpy.math as bm -import jax -import jax.numpy as jnp -import platform -import pytest -import taichi - -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) - -@taichi.func -def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: - return weight[0] - - -@taichi.func -def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): - out[index] += weight_val - -@taichi.kernel -def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), - vector: taichi.types.ndarray(ndim=1), - weight: taichi.types.ndarray(ndim=1), - out: taichi.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - taichi.loop_config(serialize=True) - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - -prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) - -def test_taichi_clean_cache(): - s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) - vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - print(out) - bm.clear_buffer_memory() - - print('kernels: ', bm.check_kernels_count()) - - bm.clean_caches() - - print('kernels: ', bm.check_kernels_count()) - +import brainpy.math as bm +import jax +import jax.numpy as jnp +import platform +import pytest + +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', 
allow_module_level=True) + +if not platform.platform().startswith('Windows'): + pytest.skip(allow_module_level=True) + +@ti.func +def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: + return weight[0] + + +@ti.func +def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): + out[index] += weight_val + +@ti.kernel +def event_ell_cpu(indices: ti.types.ndarray(ndim=2), + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + weight_val = get_weight(weight) + num_rows, num_cols = indices.shape + ti.loop_config(serialize=True) + for i in range(num_rows): + if vector[i]: + for j in range(num_cols): + update_output(out, indices[i, j], weight_val) + +prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + +def test_taichi_clean_cache(): + s = 1000 + indices = bm.random.randint(0, s, (s, 1000)) + vector = bm.random.rand(s) < 0.1 + weight = bm.array([1.0]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + print(out) + bm.clear_buffer_memory() + + print('kernels: ', bm.check_kernels_count()) + + bm.clean_caches() + + print('kernels: ', bm.check_kernels_count()) + # test_taichi_clean_cache() \ No newline at end of file diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80b..d53533247 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,7 @@ - -from ._coo_mv import * +# from ._coo_mv import * +# from ._bsr_mv import * from ._csr_mv import * from ._utils import * -from ._bsr_mv import * from ._bsr_mm import * from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 453ab387d..19800749d 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -1,22 +1,23 @@ # -*- 
coding: utf-8 -*- from functools import partial -from typing import Union, Tuple +from typing import Tuple import jax.lax -import numba import numpy as np from jax import numpy as jnp from jax.core import Primitive, ShapedArray from jax.interpreters import ad, xla from jax.lib import xla_client +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba from brainpy._src.math.interoperability import as_jax -from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound +numba = import_numba(error_if_not_found=False) + __all__ = [ 'bcsrmm', ] @@ -264,52 +265,53 @@ def bcsrmm( raise ValueError -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_k, block_size_n) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_n, block_size_k) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in 
numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val +if numba is not None: + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_k, block_size_n) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val + + + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_n, block_size_k) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + 
res_val[ns: ne] = C_tmp + return res_val def _bcsrmm_cutlass_abstract( diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 377597579..42969f435 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -1,28 +1,22 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple import jax -import numba -import numpy as np -from jax import core, dtypes from jax import numpy as jnp -from jax.interpreters import ad, mlir, xla -from jax.lib import xla_client -from jaxlib import gpu_sparse +from jax.experimental.sparse import csr +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) +from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'csrmv', @@ -37,7 +31,6 @@ def csrmv( *, shape: Tuple[int, int], transpose: bool = False, - method: str = None, ): """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. @@ -70,495 +63,6 @@ def csrmv( - ``vector``: - ``adaptive``: - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. 
- """ - if method is None: - return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - else: - return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) - - -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, - method: str = 'cusparse', -): - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute Matrix-Vector Multiplication. The candidate methods are: - - - ``cusparse``: using cuSPARSE library. - - ``scalar``: - - ``vector``: - - ``adaptive``: - - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. 
- """ - - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if method == 'cusparse': - if jax.default_backend() == 'gpu': - if data.shape[0] == 1: - data = jnp.ones(indices.shape, dtype=data.dtype) * data - if indices.dtype in [jnp.uint32, jnp.uint64]: - indices = jnp.asarray(indices, dtype=dtypes.canonicalize_dtype(jnp.int64)) - if indptr.dtype in [jnp.uint32, jnp.uint64]: - indptr = jnp.asarray(indptr, dtype=dtypes.canonicalize_dtype(jnp.int64)) - return _csrmv_cusparse_p.bind(data, - indices, - indptr, - vector, - shape=shape, - transpose=transpose) - - elif method == 'adaptive': - return _csrmv_adaptive_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'scalar': - return _csrmv_scalar_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'vector': - return _csrmv_vector_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - else: - raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.') - - -def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose): - if data.dtype not in [jnp.float32, jnp.float64]: - raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. 
' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - out_shape = shape[1] if transpose else shape[0] - return core.ShapedArray((out_shape,), data.dtype) - - -@numba.njit(fastmath=True) -def _csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # (csr mat).T @ vec - - if values.shape[0] == 1: - values = values[0] - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += values * v - else: - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += v * values[j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # csr mat @ vec - if values.shape[0] == 1: - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values * vector[col_indices[j]] - res_val[row_i] = r - else: - for row_i in numba.prange(shape[0]): - r = 0. 
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - res_val[row_i] = r - - -def _csrmv_cpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - inputs = (data, indices, indptr, vector) - description = dict(shape=shape, transpose=transpose) - if transpose: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_transpose_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_layouts, - shape_with_layout=output_layouts, - ) - - -def _csrmv_cusparse_gpu_lowering(ctx, data, indices, indptr, vector, *, shape, transpose): - data_aval, indices_aval, _, v_aval = ctx.avals_in - dtype = data_aval.dtype - if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: - raise TypeError(f"cusparse_csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. 
" - "Falling back to default implementation.") - return [gpu_sparse.cuda_csr_matvec(data, indices, indptr, vector, - shape=shape, - transpose=transpose, - data_dtype=dtype, - x_dtype=v_aval.dtype, - index_dtype=indices_aval.dtype)] - - -def _csrmv_jvp_mat(csr_prim, data_dot, data, indices, indptr, v, *, shape, transpose): - return csr_prim.bind(data_dot, indices, indptr, v, shape=shape, transpose=transpose) - - -def _csrmv_jvp_vec(prim, v_dot, data, indices, indptr, v, *, shape, transpose): - return prim.bind(data, indices, indptr, v_dot, shape=shape, transpose=transpose) - - -def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return data, indices, indptr, ad.Zero(vector) - else: - ct_vector = _csrmv_cusparse_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, ct_vector - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_cusparse_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_cusparse_p = core.Primitive('cusparse_csr_matvec') -_csrmv_cusparse_p.def_abstract_eval(_csrmv_abstract) -_csrmv_cusparse_p.def_impl(partial(xla.apply_primitive, _csrmv_cusparse_p)) -# xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation -ad.defjvp(_csrmv_cusparse_p, - partial(_csrmv_jvp_mat, _csrmv_cusparse_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_cusparse_p)) -ad.primitive_transposes[_csrmv_cusparse_p] = 
_csrmv_cusparse_transpose -register_general_batching(_csrmv_cusparse_p) -mlir.register_lowering(_csrmv_cusparse_p, _csrmv_cusparse_gpu_lowering, platform='cuda') - - -def _csr_matvec_scalar_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_scalar_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_scalar' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_scalar_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_scalar_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, 
transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_scalar_p = core.Primitive('csr_matvec_scalar') -_csrmv_scalar_p.def_abstract_eval(_csrmv_abstract) -_csrmv_scalar_p.def_impl(partial(xla.apply_primitive, _csrmv_scalar_p)) -# xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation -ad.defjvp(_csrmv_scalar_p, - partial(_csrmv_jvp_mat, _csrmv_scalar_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_scalar_p), ) -ad.primitive_transposes[_csrmv_scalar_p] = _csrmv_scalar_transpose -register_general_batching(_csrmv_scalar_p) - - -def _csr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_vector_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), 
(shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_vector_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_vector_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_vector_p = core.Primitive('csr_matvec_vector') -_csrmv_vector_p.def_abstract_eval(_csrmv_abstract) -_csrmv_vector_p.def_impl(partial(xla.apply_primitive, _csrmv_vector_p)) -# xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation -ad.defjvp(_csrmv_vector_p, - partial(_csrmv_jvp_mat, _csrmv_vector_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_vector_p), ) -ad.primitive_transposes[_csrmv_vector_p] = _csrmv_vector_transpose -register_general_batching(_csrmv_vector_p) - - -def _csr_matvec_adaptive_gpu_translation(c, data, indices, indptr, row_blocks, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_adaptive_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - 
else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, row_blocks, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(row_blocks), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_adaptive_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_adaptive_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_adaptive_p = core.Primitive('csr_matvec_adaptive') -_csrmv_adaptive_p.def_abstract_eval(_csrmv_abstract) -_csrmv_adaptive_p.def_impl(partial(xla.apply_primitive, _csrmv_adaptive_p)) -# xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation -# 
xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation -ad.defjvp(_csrmv_adaptive_p, - partial(_csrmv_jvp_mat, _csrmv_adaptive_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_adaptive_p), ) -ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose -register_general_batching(_csrmv_adaptive_p) - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. 
- Returns ------- y : ndarry @@ -593,171 +97,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - -@ti.kernel -def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - -@ti.kernel -def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. 
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - def raw_csrmv_taichi( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -767,17 +106,22 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False, ): + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p + if data.shape[0] != 1: + if bm.get_platform() == 'gpu': + 
return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] else: - prim = _csr_matvec_transpose_heter_p + if transpose: + prim = _csr_matvec_transpose_heter_p + else: + prim = _csr_matvec_heter_p else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p + if transpose: + prim = _csr_matvec_transpose_homo_p else: - prim = _csr_matvec_heter_p + prim = _csr_matvec_homo_p return prim(data, indices, @@ -788,25 +132,193 @@ def raw_csrmv_taichi( shape=shape) -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim +if ti is not None: + + # ------------- + # CPU operators + # ------------- + @ti.kernel + def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += vector[row_i] * values[j] + + + @ti.kernel + def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for 
row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += vector[col_indices[j]] + out[row_i] = r * value + + + @ti.kernel + def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + @ti.kernel + def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += value * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. 
+ j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += vector[col_indices[j]] + j += 32 + out[row_i] += value * r + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += values[j] * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += values[j] * vector[col_indices[j]] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_transpose( + ct, data, indices, indptr, vector, *, outs, transpose, shape, + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(vector): + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(vector) 
if type(ct[0]) is ad.Zero else ct_vector) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] + + return ct_data, indices, indptr, vector + + + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) + prim.def_transpose_rule(_sparse_csr_matvec_transpose) + return prim + + # transpose homo + _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) -# transpose homo -_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) + # no transpose homo + _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, + gpu_kernel=_sparse_csr_matvec_homo_gpu) -# no transpose homo -_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) + # transpose heter + _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) -# transpose heter -_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + # no transpose heter + _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, + gpu_kernel=_sparse_csr_matvec_heter_gpu) -# no transpose heter -_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + # heter cusparse + 
_csr_matvec_cusparse_p = csr.csr_matvec_p + register_general_batching(_csr_matvec_cusparse_p) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/_utils.py index a1dc9190e..f5b74e5eb 100644 --- a/brainpy/_src/math/sparse/_utils.py +++ b/brainpy/_src/math/sparse/_utils.py @@ -3,9 +3,8 @@ import warnings from typing import Tuple -import jax import numpy as np -from jax import core, numpy as jnp, dtypes +from jax import core, numpy as jnp from jax.interpreters import mlir, ad from jaxlib import gpu_sparse diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 2c75f0901..ec448e658 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -5,10 +5,14 @@ import jax from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -# bm.set_platform('gpu') +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) seed = 1234 diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py deleted file mode 100644 index b73217496..000000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_old.py +++ /dev/null @@ -1,352 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import jax -import pytest -from absl.testing import parameterized -import platform -import brainpy as bp -import brainpy.math as bm - -pytest.skip('Old implementation.', allow_module_level=True) - -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') -scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') - - -class 
Test_cusparse_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.ones(indices.shape).value * homo_data - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - v=[-1., 0., 1.] 
- ) - def test_homo_vmap(self, transpose, shape, v): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(heter_data) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] 
- ) - def test_homo_grad(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() - if transpose else - ((dense * a) @ vector).sum()), - argnums=0) - - r1 = csr_f1(homo_data) - r2 = dense_f1(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, - shape=shape, transpose=transpose).sum()) - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) - - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, - shape=shape, transpose=transpose).sum(), - argnums=(0, 1)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() - if transpose else - ((dense * a) @ v).sum()), - argnums=(0, 1)) - - r5 = csr_f3(homo_data, vector) - r6 = dense_f3(homo_data, vector) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - ) - def test_heter(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) 
- - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, - shape=shape, transpose=transpose) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r2 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, 
- shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - - r1 = csr_f1(heter_data) - r2 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r2 = r2[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - -class Test_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - homo_data=[-1., 0., 0.1, 1.], - shape=[(100, 200), (10, 1000), (2, 2000)], - ) - def test_homo(self, shape, homo_data): - conn = bp.conn.FixedProb(0.1) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(123) - vector = rng.random(shape[1]) - vector = bm.as_jax(vector) - - # csrmv - r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - - heter_data = bm.ones(indices.shape).to_jax() * homo_data - r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r4)) - self.assertTrue(bm.allclose(r1, r5)) - 
self.assertTrue(bm.allclose(r1, r6)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - rdense = dense @ vector - self.assertTrue(bm.allclose(r1, rdense)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = bm.as_jax(rng.random(indices.shape)) - vector = bm.as_jax(rng.random(shape[1])) - - r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = dense @ vector - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - heter_data = rng.random(indices.shape) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[1]) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - dense_f1 = jax.grad(lambda a: (a @ vector).sum()) - - r1 = csr_f1(heter_data) - r2 = csr_f2(heter_data) - r3 = csr_f3(heter_data) - - d1 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - d1 = d1[rows, 
cols] - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, d1)) - - # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) - # r4 = csr_f4(vector) - # r5 = csr_f5(vector) - # r6 = csr_f6(vector) - # d2 = dense_f2(vector) - # self.assertTrue(bm.allclose(r4, r5)) - # self.assertTrue(bm.allclose(r4, r6)) - # self.assertTrue(bm.allclose(r4, d2)) - - bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py index 6823ebabd..db6e7debc 100644 --- a/brainpy/_src/math/tests/test_tifunc.py +++ b/brainpy/_src/math/tests/test_tifunc.py @@ -1,122 +1,124 @@ -# -*- coding: utf-8 -*- - -import jax -import jax.numpy as jnp -import pytest - -pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") -import brainpy.math as bm -import taichi as ti -import matplotlib.pyplot as plt -import os - - -bm.set_platform('cpu') - - -def test_taichi_random(): - @ti.kernel - def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - key = bm.tifunc.lfsr88_key(seed[0]) - for i in range(out.shape[0]): - key, result = bm.tifunc.lfsr88_rand(key) - out[i] = result - - @ti.kernel - def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range(out.shape[0]): - out[i] = bm.tifunc.taichi_lcg_rand(seed) - - @ti.kernel - def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high 
= low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) - - @ti.kernel - def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) - - @ti.kernel - def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), - mu_sigma: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - mu = mu_sigma[0] - sigma = mu_sigma[1] - - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) - - n = 100000 - seed = jnp.array([1234, ], dtype=jnp.uint32) - low_high = jnp.array([0, 10]) - mu_sigma = jnp.array([0, 1]) - - prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, - gpu_kernel=test_taichi_lfsr88) - - - prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, - gpu_kernel=test_taichi_lcg_rand) - prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, - gpu_kernel=test_taichi_uniform_int_distribution) - prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, - gpu_kernel=test_taichi_uniform_real_distribution) - prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, - gpu_kernel=test_taichi_normal_distribution) - - file_path = os.path.dirname(os.path.abspath(__file__)) - - out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LFSR88 random number generator") - plt.savefig(file_path + "/lfsr88.png") - plt.close() - - out = prim_lcg_rand(seed, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LCG random number generator") 
- plt.savefig(file_path + "/lcg_rand.png") - plt.close() - - out = prim_uniform_int_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) - # show the distribution of out - plt.hist(out, bins=10) - plt.title("Uniform int distribution (0, 10)") - plt.savefig(file_path + "/uniform_int_distribution.png") - plt.close() - - out = prim_uniform_real_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("Uniform real distribution (0, 10)") - plt.savefig(file_path + "/uniform_real_distribution.png") - plt.close() - - out = prim_normal_distribution(seed, mu_sigma, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.title("Normal distribution mu=0, sigma=1") - plt.hist(out, bins=100) - plt.savefig(file_path + "/normal_distribution.png") - - -# TODO; test default types +# -*- coding: utf-8 -*- + +import jax +import jax.numpy as jnp +import pytest + +pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") +import brainpy.math as bm +import matplotlib.pyplot as plt +import os + +from brainpy._src.dependency_check import import_taichi + +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + +bm.set_platform('cpu') + + +def test_taichi_random(): + @ti.kernel + def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), + out: ti.types.ndarray(ndim=1, dtype=ti.f32)): + key = bm.tifunc.lfsr88_key(seed[0]) + for i in range(out.shape[0]): + key, result = bm.tifunc.lfsr88_rand(key) + out[i] = result + + @ti.kernel + def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range(out.shape[0]): + out[i] = bm.tifunc.taichi_lcg_rand(seed) + + @ti.kernel + def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + 
out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) + + @ti.kernel + def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) + + @ti.kernel + def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), + mu_sigma: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + mu = mu_sigma[0] + sigma = mu_sigma[1] + + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) + + n = 100000 + seed = jnp.array([1234, ], dtype=jnp.uint32) + low_high = jnp.array([0, 10]) + mu_sigma = jnp.array([0, 1]) + + prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, + gpu_kernel=test_taichi_lfsr88) + + prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, + gpu_kernel=test_taichi_lcg_rand) + prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, + gpu_kernel=test_taichi_uniform_int_distribution) + prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, + gpu_kernel=test_taichi_uniform_real_distribution) + prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, + gpu_kernel=test_taichi_normal_distribution) + + file_path = os.path.dirname(os.path.abspath(__file__)) + + out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LFSR88 random number generator") + plt.savefig(file_path + "/lfsr88.png") + plt.close() + + out = prim_lcg_rand(seed, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # 
show the distribution of out + plt.hist(out, bins=100) + plt.title("LCG random number generator") + plt.savefig(file_path + "/lcg_rand.png") + plt.close() + + out = prim_uniform_int_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) + # show the distribution of out + plt.hist(out, bins=10) + plt.title("Uniform int distribution (0, 10)") + plt.savefig(file_path + "/uniform_int_distribution.png") + plt.close() + + out = prim_uniform_real_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("Uniform real distribution (0, 10)") + plt.savefig(file_path + "/uniform_real_distribution.png") + plt.close() + + out = prim_normal_distribution(seed, mu_sigma, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.title("Normal distribution mu=0, sigma=1") + plt.hist(out, bins=100) + plt.savefig(file_path + "/normal_distribution.png") + +# TODO; test default types diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index a9ee39f4a..9cfd39e1a 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -1,7 +1,7 @@ -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, raise_taichi_not_found from . 
import defaults -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ # taichi function for other utilities @@ -16,349 +16,330 @@ 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', ] +if ti is not None: -@ti.func -def _lcg_rand(state: ti.types.ndarray(ndim=1)): - # LCG constants - state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) - return state[0] + ############################################# + # Random Number Generator: LFSR88 algorithm # + ############################################# + @ti.func + def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). -@ti.func -def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): - """ - Generate a random number using the Taichi LCG algorithm. + This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. - Parameters: - seed (ti.types.ndarray): The seed value for the random number generator. + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - Returns: - float: A random number between 0 and 1. - """ + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3 MUST be larger than + 1, 7, and 15 respectively. + */ - return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR88 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) -############################################# -# Random Number Generator: LFSR88 algorithm # -############################################# + @ti.func + def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). 
-@ti.func -def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. - This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. + Returns: + ti.math.uvec4: The next random key. + """ + b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) + s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b + b = ((key[1] << 2) ^ key[1]) >> 25 + s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b + b = ((key[2] << 3) ^ key[2]) >> 11 + s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b + return ti.math.uvec4(s1, s2, s3, b) - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3 MUST be larger than - 1, 7, and 15 respectively. - */ + @ti.func + def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - Args: - seed: int. The seed value for the random number generator. + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ - Returns: - ti.math.uvec4: The random key for the LFSR88 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) + key, r = lfsr88_randn(key, epsilon) + return key, mu + sigma * r -@ti.func -def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). + @ti.func + def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with the standard normal distribution using the LFSR88 algorithm. 
- Args: - key: The state value for the random number generator. + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - Returns: - ti.math.uvec4: The next random key. - """ - b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) - s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b - b = ((key[1] << 2) ^ key[1]) >> 25 - s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b - b = ((key[2] << 3) ^ key[2]) >> 11 - s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b - return ti.math.uvec4(s1, s2, s3, b) + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + """ -@ti.func -def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. + key, u1 = lfsr88_rand(key) + key, u2 = lfsr88_rand(key) - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - key, r = lfsr88_randn(key, epsilon) - return key, mu + sigma * r + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) -@ti.func -def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with the standard normal distribution using the LFSR88 algorithm. + return key, z2 - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). - References: - Box–Muller transform. 
https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + @ti.func + def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. - """ + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) - key, u1 = lfsr88_rand(key) - key, u2 = lfsr88_rand(key) - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + @ti.func + def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): + key = lfsr88_next_key(key) + return key, dtype(key[0] ^ key[1] ^ key[2]) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + @ti.func + def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. - return key, z2 + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -@ti.func -def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. 
+ @ti.func + def lfsr88_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) + Args: + key: The state value used for random number generation. + """ + key = lfsr88_next_key(key) + return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): - key = lfsr88_next_key(key) - return key, dtype(key[0] ^ key[1] ^ key[2]) + ############################################## + # Random Number Generator: LFSR113 algorithm # + ############################################## + @ti.func + def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). -@ti.func -def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. + This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3, s4 MUST be larger than + 1, 7, 15, and 127 respectively. 
+ */ -@ti.func -def lfsr88_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. + Args: + seed: int. The seed value for the random number generator. - Args: - key: The state value used for random number generation. - """ - key = lfsr88_next_key(key) - return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Returns: + ti.math.uvec4: The random key for the LFSR113 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) -############################################## -# Random Number Generator: LFSR113 algorithm # -############################################## + @ti.func + def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. -@ti.func -def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Returns: + ti.math.uvec4: The next random key. + """ + z1 = key[0] + z2 = key[1] + z3 = key[2] + z4 = key[3] + b = ((z1 << 6) ^ z1) >> 13 + z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) + b = ((z2 << 2) ^ z2) >> 27 + z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) + b = ((z3 << 13) ^ z3) >> 21 + z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) + b = ((z4 << 3) ^ z4) >> 12 + z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) + return ti.math.uvec4(z1, z2, z3, z4) - This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. 
- Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + @ti.func + def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3, s4 MUST be larger than - 1, 7, 15, and 127 respectively. - */ + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ - Args: - seed: int. The seed value for the random number generator. + key, r = lfsr113_randn(key, epsilon) + return key, ti.cast(mu + sigma * r, defaults.ti_float) - Returns: - ti.math.uvec4: The random key for the LFSR113 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) + @ti.func + def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with standard normal distribution using the LFSR113 algorithm. -@ti.func -def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - Args: - key: The state value for the random number generator. + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - Returns: - ti.math.uvec4: The next random key. 
- """ - z1 = key[0] - z2 = key[1] - z3 = key[2] - z4 = key[3] - b = ((z1 << 6) ^ z1) >> 13 - z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) - b = ((z2 << 2) ^ z2) >> 27 - z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) - b = ((z3 << 13) ^ z3) >> 21 - z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) - b = ((z4 << 3) ^ z4) >> 12 - z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) - return ti.math.uvec4(z1, z2, z3, z4) + """ + key, u1 = lfsr113_rand(key) + key, u2 = lfsr113_rand(key) -@ti.func -def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - key, r = lfsr113_randn(key, epsilon) - return key, ti.cast(mu + sigma * r, defaults.ti_float) + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with standard normal distribution using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. 
https://en.wikipedia.org/wiki/Marsaglia_polar_method + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr113_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) - """ - key, u1 = lfsr113_rand(key) - key, u2 = lfsr113_rand(key) + @ti.func + def lfsr113_randint(key: ti.types.vector(4, ti.u32)): + key = lfsr113_next_key(key) + return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + @ti.func + def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr113_next_key(key) + r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) + + + @ti.func + def lfsr113_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. - return key, z2 + Args: + key: The state value used for random number generation.
+ """ + key = lfsr113_next_key(key) + return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. + ########################### + # Reductions: warp reduce # + ########################### - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr113_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def warp_reduce_sum_all(val): + """ + Warp reduce sum. + Args: + val (float): The value to be reduced. -@ti.func -def lfsr113_randint(key: ti.types.vector(4, ti.u32)): - key = lfsr113_next_key(key) - return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) + Returns: + float: The reduced value. + """ + for i in ti.static(range(1, 32)): + val += ti.static(ti.simt.warp.shfl_xor(val, i)) + return val -@ti.func -def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. + @ti.func + def warp_reduce_sum(val): + """ + Warp reduce sum. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - -@ti.func -def lfsr113_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + Args: + val (float): The value to be reduced. 
- Args: - key: The state value used for random number generation. - """ - key = lfsr113_next_key(key) - return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Returns: + float: The reduced value. + """ + for offset in ti.static((16, 8, 4, 2, 1)): + val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) + return val -########################### -# Reductions: warp reduce # -########################### - - -@ti.func -def warp_reduce_sum_all(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for i in ti.static(range(1, 32)): - val += ti.static(ti.simt.warp.shfl_xor(val, i)) - return val - - -@ti.func -def warp_reduce_sum(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for offset in ti.static((16, 8, 4, 2, 1)): - val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) - return val +else: + for func in __all__: + globals()[func] = raise_taichi_not_found \ No newline at end of file diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index dd6865e64..6f2411ee8 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -1,134 +1,133 @@ -# -*- coding: utf-8 -*- - - -import unittest -import brainpy as bp -import brainpy.math as bm - - -class TestDSRunner(unittest.TestCase): - def test1(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 - - ds = ExampleDS() - runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) 
- - def test_t_and_dt(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 * bp.share['dt'] - - runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_DSView(self): - class EINet(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet, self).__init__() - - # network size - num_exc = int(800 * scale) - num_inh = int(200 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - bm.random.seed() - - net = EINet(scale=1., method='exp_auto') - # with JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) 
- - # without JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) - - - -class TestMemoryEfficient(unittest.TestCase): - pass - - - - - - -# class TestMonitor(TestCase): -# def test_1d_array(self): -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones(1) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 -# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) -# -# def test_2d_array(): -# set(dt=0.1) -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones((2, 2)) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# def test_monitor_with_every(): -# set(dt=0.1) -# -# # try1: 2d array -# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try1.run(100.) -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# # try2: 1d array -# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try2.a = np.array([1., 1.]) -# try2.run(100.) -# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 2, axis=1) -# assert np.allclose(series, try2.mon.a) -# -# # try2: scalar -# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try3.a = 1. -# try3.run(100.) -# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 -# series = np.arange(2, 1002, 1. 
/ 0.1).reshape((-1, 1)) -# assert np.allclose(series, try3.mon.a) +# -*- coding: utf-8 -*- + +import pytest +import unittest +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestDSRunner(unittest.TestCase): + def test1(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 + + ds = ExampleDS() + runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_t_and_dt(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 * bp.share['dt'] + + runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_DSView(self): + class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + + # network size + num_exc = int(800 * scale) + num_inh = int(200 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. 
+ + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + + bm.random.seed() + + net = EINet(scale=1., method='exp_auto') + # with JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) + + # without JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) + + +class TestMemoryEfficient(unittest.TestCase): + pass + +# class TestMonitor(TestCase): +# def test_1d_array(self): +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones(1) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 +# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) +# +# def test_2d_array(): +# set(dt=0.1) +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones((2, 2)) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# def test_monitor_with_every(): +# set(dt=0.1) +# +# # try1: 2d array +# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try1.run(100.) 
+# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# # try2: 1d array +# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try2.a = np.array([1., 1.]) +# try2.run(100.) +# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 2, axis=1) +# assert np.allclose(series, try2.mon.a) +# +# # try2: scalar +# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try3.a = 1. +# try3.run(100.) +# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# assert np.allclose(series, try3.mon.a) diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py deleted file mode 100644 index cbc710dba..000000000 --- a/brainpy/_src/tools/functions.py +++ /dev/null @@ -1,192 +0,0 @@ -import inspect -from functools import partial -from operator import attrgetter -from types import MethodType - -__all__ = [ - 'compose', 'pipe' -] - - -def identity(x): - """ Identity function. Return x - - >>> identity(3) - 3 - """ - return x - - -def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): - """ Like @property, but returns ``classval`` when used as a class attribute - - >>> class MyClass(object): - ... '''The class docstring''' - ... @instanceproperty(classval=__doc__) - ... def __doc__(self): - ... return 'An object docstring' - ... @instanceproperty - ... def val(self): - ... return 42 - ... 
- >>> MyClass.__doc__ - 'The class docstring' - >>> MyClass.val is None - True - >>> obj = MyClass() - >>> obj.__doc__ - 'An object docstring' - >>> obj.val - 42 - """ - if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) - - -class InstanceProperty(property): - """ Like @property, but returns ``classval`` when used as a class attribute - - Should not be used directly. Use ``instanceproperty`` instead. - """ - - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): - self.classval = classval - property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - - def __get__(self, obj, type=None): - if obj is None: - return self.classval - return property.__get__(self, obj, type) - - def __reduce__(self): - state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) - return InstanceProperty, state - - -class Compose(object): - """ A composition of functions - - See Also: - compose - """ - __slots__ = 'first', 'funcs' - - def __init__(self, funcs): - funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] - - def __call__(self, *args, **kwargs): - ret = self.first(*args, **kwargs) - for f in self.funcs: - ret = f(ret) - return ret - - def __getstate__(self): - return self.first, self.funcs - - def __setstate__(self, state): - self.first, self.funcs = state - - @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): - """Generate a docstring for the composition of fs. - """ - if not fs: - # Argument name for the docstring. - return '*args, **kwargs' - - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) - - try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) - except AttributeError: - # One of our callables does not have a `__name__`, whatever. 
- return 'A composition of functions' - - @property - def __name__(self): - try: - return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) - ) - except AttributeError: - return type(self).__name__ - - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first,) + self.funcs))) - - def __eq__(self, other): - if isinstance(other, Compose): - return other.first == self.first and other.funcs == self.funcs - return NotImplemented - - def __ne__(self, other): - equality = self.__eq__(other) - return NotImplemented if equality is NotImplemented else not equality - - def __hash__(self): - return hash(self.first) ^ hash(self.funcs) - - # Mimic the descriptor behavior of python functions. - # i.e. let Compose be called as a method when bound to a class. - # adapted from - # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): - return self if obj is None else MethodType(self, obj) - - # introspection with Signature is only possible from py3.3+ - @instanceproperty - def __signature__(self): - base = inspect.signature(self.first) - last = inspect.signature(self.funcs[-1]) - return base.replace(return_annotation=last.return_annotation) - - __wrapped__ = instanceproperty(attrgetter('first')) - - -def compose(*funcs): - """ Compose functions to operate in series. - - Returns a function that applies other functions in sequence. - - Functions are applied from right to left so that - ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. - - If no arguments are provided, the identity function (f(x) = x) is returned. - - >>> inc = lambda i: i + 1 - >>> compose(str, inc)(3) - '4' - """ - if not funcs: - return identity - if len(funcs) == 1: - return funcs[0] - else: - return Compose(funcs) - - -def pipe(*funcs): - """ Pipe a value through a sequence of functions - - I.e. 
``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` - - We think of the value as progressing through a pipe of several - transformations, much like pipes in UNIX - - - >>> double = lambda i: 2 * i - >>> pipe(double, str)(3) - '6' - """ - return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/progress.py b/brainpy/_src/tools/progress.py new file mode 100644 index 000000000..13b6a1574 --- /dev/null +++ b/brainpy/_src/tools/progress.py @@ -0,0 +1,519 @@ +"""Python utilities required by Keras.""" + +import binascii +import codecs +import importlib +import marshal +import os +import re +import sys +import time +import types as python_types + +import numpy as np + + +# isort: off + + +def func_dump(func): + """Serializes a user defined function. + + Args: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Args: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. + """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. 
+ + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) + + +class Progbar: + """Displays a progress bar. + + Args: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). 
+ """ + + def __init__( + self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_at_epoch_end = None + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. 
+ value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + now = time.time() + info = f" - {now - self._start:.0f}s" + if current == self.target: + self._time_at_epoch_end = now + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + message += "\b" * prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." 
* (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + message += bar + + time_per_unit = self._estimate_step_duration(current, now) + + if self.target is None or finalize: + info += self._format_time(time_per_unit, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = f" - ETA: {eta_format}" + + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if finalize: + info += "\n" + + message += info + print_msg(message, line_break=False) + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = ("%" + str(numdigits) + "d/%d") % (current, self.target) + info = count + info + for k in self._values_order: + info += f" - {k}:" + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + if self._time_at_epoch_end: + time_per_epoch = ( + self._time_at_epoch_end - self._time_at_epoch_start + ) + avg_time_per_step = time_per_epoch / self.target + self._time_at_epoch_start = now + self._time_at_epoch_end = None + info += " -" + self._format_time(time_per_epoch, "epoch") + info += " -" + self._format_time( + avg_time_per_step, self.unit_name + ) + info += "\n" + message += info + print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + 
self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch) + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + Returns: + a string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" + else: + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. + + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. 
as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit + else: + return 0 + + def _update_stateful_metrics(self, stateful_metrics): + self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Args: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [ + (i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches) + ] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Args: + arrays: Single array or list of arrays. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError( + "The stop argument has to be None if the value of start " + f"is a list. 
Received start={start}, stop={stop}" + ) + elif isinstance(arrays, list): + if hasattr(start, "__len__"): + # hdf5 datasets only support list objects as indices + if hasattr(start, "shape"): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + return [ + None + if x is None + else None + if not hasattr(x, "__getitem__") + else x[start:stop] + for x in arrays + ] + else: + if hasattr(start, "__len__"): + if hasattr(start, "shape"): + start = start.tolist() + return arrays[start] + if hasattr(start, "__getitem__"): + return arrays[start:stop] + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def to_snake_case(name): + intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != "_": + return insecure + return "private" + insecure + + +def check_for_unexpected_keys(name, input_dict, expected_values): + unknown = set(input_dict.keys()).difference(expected_values) + if unknown: + raise ValueError( + f"Unknown entries in {name} dictionary: {list(unknown)}. 
" + f"Only expected following keys: {expected_values}" + ) + + +def validate_kwargs( + kwargs, allowed_kwargs, error_message="Keyword argument not understood:" +): + """Checks that all keyword arguments are in the set of allowed keys.""" + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError(error_message, kwarg) + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) + + +def populate_dict_with_module_objects(target_dict, modules, obj_filter): + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj + + +class LazyLoader(python_types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies.""" + + def __init__(self, local_name, parent_module_globals, name): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + super().__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on + # lookups that fail). 
+ self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + +def print_msg(message, line_break=True): + """Print the message to absl logging or stdout.""" + if line_break: + sys.stdout.write(message + "\n") + else: + sys.stdout.write(message) + sys.stdout.flush() diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py deleted file mode 100644 index c285e561a..000000000 --- a/brainpy/_src/tools/tests/test_functions.py +++ /dev/null @@ -1,24 +0,0 @@ - -import unittest - -import brainpy as bp -import brainpy.math as bm - - -class TestFunction(unittest.TestCase): - def test_compose(self): - f = lambda a: a + 1 - g = lambda a: a * 10 - fun1 = bp.tools.compose(f, g) - fun2 = bp.tools.pipe(g, f) - - arr = bm.random.randn(10) - r1 = fun1(arr) - r2 = fun2(arr) - groundtruth = f(g(arr)) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, groundtruth)) - bm.clear_buffer_memory() - - - diff --git a/brainpy/errors.py b/brainpy/errors.py index e59bb326c..453c9c818 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -38,7 +38,12 @@ class AnalyzerError(BrainPyError): class PackageMissingError(BrainPyError): """The package missing error. """ - pass + + @classmethod + def by_purpose(cls, name, purpose): + err = (f'"{name}" must be installed when the user wants to use {purpose}. 
\n' + f'Please install through "pip install {name}".') + return cls(err) class BackendNotInstalled(BrainPyError): @@ -236,9 +241,5 @@ def __init__(self, name): ''') - - class SharedArgError(BrainPyError): pass - - diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 02f671345..9a64f9f25 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -1,103 +1,102 @@ -# -*- coding: utf-8 -*- - -# data structure -from .ndarray import * -from .delayvars import * -from .interoperability import * -from .datatypes import * -from .compat_numpy import * -from .compat_tensorflow import * -from .compat_pytorch import * -from .einops import * - -# functions -from .activations import * -from . import activations - -# operators -from .pre_syn_post import * -from .op_register import * -from . import surrogate, event, sparse, jitconn - -# Variable and Objects for object-oriented JAX transformations -from .oo_transform import * - -# environment settings -from .modes import * -from .environment import * -from .scales import * -from .others import * - -# high-level numpy operations -from . import fft -from . import linalg -from . import random - -# taichi operations -from . import tifunc - -# others -from . import sharding - -import jax.numpy as jnp -from jax import config - -del jnp, config - -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - -from brainpy._src.math import defaults -from brainpy._src.deprecations import deprecation_getattr -__deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. 
Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_homo instead.", - jitconn.event_mv_prob_homo), - 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", - jitconn.event_mv_prob_uniform), - 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_normal instead.", - jitconn.event_mv_prob_normal), - 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_homo instead.", - jitconn.mv_prob_homo), - 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_uniform instead.", - jitconn.mv_prob_uniform), - 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_normal instead.", - jitconn.mv_prob_normal), - 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " - "Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'cusparse_coo_matvec': ("brainpy.math.cusparse_coo_matvec is deprecated. " - "Use brainpy.math.sparse.coomv instead.", - sparse.coomv), - 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " - "Use brainpy.math.sparse.coo_to_csr instead.", - sparse.coo_to_csr), - 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " - "Use brainpy.math.sparse.csr_to_coo instead.", - sparse.csr_to_coo), - 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " - "Use brainpy.math.sparse.csr_to_dense instead.", - sparse.csr_to_dense), - 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. 
" - "Use brainpy.math.event.csr_to_dense instead.", - event.csrmv), - 'event_info': ("brainpy.math.event_info is deprecated. " - "Use brainpy.math.event.info instead.", - event.info), -} - -__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) -del deprecation_getattr, defaults +# -*- coding: utf-8 -*- + +# data structure +from .ndarray import * +from .delayvars import * +from .interoperability import * +from .datatypes import * +from .compat_numpy import * +from .compat_tensorflow import * +from .compat_pytorch import * +from .einops import * + +# functions +from .activations import * +from . import activations + +# operators +from .pre_syn_post import * +from .op_register import * +from . import surrogate, event, sparse, jitconn + +# Variable and Objects for object-oriented JAX transformations +from .oo_transform import * + +# environment settings +from .modes import * +from .environment import * +from .scales import * +from .others import * + +# high-level numpy operations +from . import fft +from . import linalg +from . import random + +# taichi operations +from . import tifunc + +# others +from . import sharding + +import jax.numpy as jnp +from jax import config + +del jnp, config + +from brainpy._src.math.surrogate._compt import ( + spike_with_sigmoid_grad as spike_with_sigmoid_grad, + spike_with_linear_grad as spike_with_linear_grad, + spike_with_gaussian_grad as spike_with_gaussian_grad, + spike_with_mg_grad as spike_with_mg_grad, +) + +from brainpy._src.math import defaults +from brainpy._src.deprecations import deprecation_getattr +from brainpy._src.dependency_check import import_taichi, import_numba + +import_taichi(error_if_not_found=False) +import_numba(error_if_not_found=False) + +__deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. 
Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_homo instead.", + jitconn.event_mv_prob_homo), + 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", + jitconn.event_mv_prob_uniform), + 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_normal instead.", + jitconn.event_mv_prob_normal), + 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_homo instead.", + jitconn.mv_prob_homo), + 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_uniform instead.", + jitconn.mv_prob_uniform), + 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_normal instead.", + jitconn.mv_prob_normal), + 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " + "Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " + "Use brainpy.math.sparse.coo_to_csr instead.", + sparse.coo_to_csr), + 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " + "Use brainpy.math.sparse.csr_to_coo instead.", + sparse.csr_to_coo), + 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " + "Use brainpy.math.sparse.csr_to_dense instead.", + sparse.csr_to_dense), + 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. 
" + "Use brainpy.math.event.csr_to_dense instead.", + event.csrmv), +} + +__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) +del deprecation_getattr, defaults diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 3b0c3f517..e4570f6fd 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - # add as add, + add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 0a17cae7c..02e98b8f3 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,5 +1,3 @@ - from brainpy._src.math.event import ( csrmv as csrmv, - info as info, ) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 90a028b7e..a87d27d58 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -1,10 +1,10 @@ -from brainpy._src.math.jitconn import ( - event_mv_prob_homo as event_mv_prob_homo, - event_mv_prob_uniform as event_mv_prob_uniform, - event_mv_prob_normal as event_mv_prob_normal, - - mv_prob_homo as mv_prob_homo, - mv_prob_uniform as mv_prob_uniform, - mv_prob_normal as mv_prob_normal, -) - +from brainpy._src.math.jitconn import ( + event_mv_prob_homo as event_mv_prob_homo, + event_mv_prob_uniform as event_mv_prob_uniform, + event_mv_prob_normal as event_mv_prob_normal, + + mv_prob_homo as mv_prob_homo, + mv_prob_uniform as mv_prob_uniform, + mv_prob_normal as mv_prob_normal, +) + diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 7654731d8..548a987d0 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,7 +59,3 @@ eval_shape as eval_shape, ) -from brainpy._src.math.object_transform.variables import ( - VariableStack as VariableStack, -) - diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index a48268ef4..c0fcb67ae 100644 
--- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- - - -from brainpy._src.math.op_register import ( - CustomOpByNumba, - compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, -) - -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy._src.math.op_register.ad_support import defjvp - - +# -*- coding: utf-8 -*- +from brainpy._src.math.op_register import ( + CustomOpByNumba, + compile_cpu_signature_with_numba, + clean_caches, + check_kernels_count, +) + +from brainpy._src.math.op_register.base import XLACustomOp +from brainpy._src.math.op_register.ad_support import defjvp + + diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9c..aa86679ec 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,8 +1,9 @@ -from brainpy._src.math.sparse import ( - csrmv, - coomv, +from brainpy._src.math.sparse import ( seg_matmul, +) +from brainpy._src.math.sparse import ( + csrmv, csr_to_dense as csr_to_dense, csr_to_coo as csr_to_coo, diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py index 63f3cbe45..bea49c220 100644 --- a/brainpy/math/tifunc.py +++ b/brainpy/math/tifunc.py @@ -1,26 +1,25 @@ -# -*- coding: utf-8 -*- - -from brainpy._src.math.tifunc import ( - taichi_lcg_rand, - - # warp reduction primitives - warp_reduce_sum, - - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand -) +# -*- coding: utf-8 -*- + +from brainpy._src.math.tifunc import ( + + # warp reduction primitives + warp_reduce_sum, + + # random number generator + lfsr88_key, + lfsr88_next_key, + lfsr88_normal, + lfsr88_randn, + lfsr88_random_integers, + lfsr88_randint, + lfsr88_uniform, + lfsr88_rand, + 
lfsr113_key, + lfsr113_next_key, + lfsr113_normal, + lfsr113_randn, + lfsr113_random_integers, + lfsr113_randint, + lfsr113_uniform, + lfsr113_rand +) diff --git a/brainpy/tools.py b/brainpy/tools.py index 233269dc5..0f3a4c0ef 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -45,9 +45,4 @@ ) -from brainpy._src.tools.functions import ( - compose as compose, - pipe as pipe, -) - diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 0b78315ab..5c8cba0fd 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,52 +3,13 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. -Advanced Math -------------- .. toctree:: - :maxdepth: 1 - - tutorial_advanced/compilation.ipynb - tutorial_advanced/differentiation.ipynb - - -Interoperation --------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/integrate_flax_into_brainpy.ipynb - tutorial_advanced/integrate_bp_lif_into_flax.ipynb - tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb - - -Brain Dynamics Dedicated Operators ----------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/operator_custom_with_numba.ipynb - tutorial_advanced/operator_custom_with_taichi.ipynb - - -Developer Guides ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/contributing.md - - -Others ------- - -.. 
toctree:: - :maxdepth: 1 - - tutorial_advanced/advanced_lowdim_analysis.ipynb + :maxdepth: 2 + tutorial_advanced/1_advanced_math.rst + tutorial_advanced/2_interoperation.rst + tutorial_advanced/3_dedicated_operators.rst + tutorial_advanced/4_developer_guides.rst + tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 9ed9cf46a..754e0d81d 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,5 +77,4 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape - VariableStack diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 2e0bb1905..46ce3822f 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -10,285 +10,71 @@ Installation Linux, and MacOS. It only relies on Python libraries. -Installation with pip ---------------------- +Minimum requirements +-------------------- -You can install ``BrainPy`` from the `pypi `_. -To do so, use: +To install brainpy with minimum requirements (only depends on ``jax``), you can use: .. code-block:: bash - pip install brainpy - -To update the latest BrainPy, you can use - -.. code-block:: bash - - pip install -U brainpy - - -If you want to install the pre-release version (the latest development version) -of BrainPy, you can use: - -.. code-block:: bash - - pip install --pre brainpy - - - -Installation from source ------------------------- - -If you decide not to use ``pip``, you can install ``BrainPy`` from -`GitHub `_, -or `OpenI `_. - -To do so, use: - -.. 
code-block:: bash - - pip install git+https://github.com/PKU-NIP-Lab/BrainPy + pip install brainpy[cpu_mini] # for CPU # or - pip install git+https://git.openi.org.cn/OpenI/BrainPy + pip install brainpy[cuda_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for GPU (Linux only) -Dependency 1: NumPy --------------------------------- -In order to make BrainPy work normally, users should install -several dependent Python packages. +CPU with all dependencies +------------------------- -The basic function of ``BrainPy`` only relies on `NumPy`_, which is very -easy to install through ``pip`` or ``conda``: +To install a CPU-only version of BrainPy, which might be useful for doing local development on a laptop, you can run .. code-block:: bash - pip install numpy - - # or - - conda install numpy - -Dependency 2: JAX ------------------ - -BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables -users to run Python code on CPU, GPU, and TPU devices. Core functionalities of -BrainPy (>=2.0.0) have been migrated to the JAX backend. - -Linux -^^^^^ - -Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or -later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS -systems are available at + pip install brainpy[cpu] -- for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html -- for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -If you want to install a CPU-only version of `jax` and `jaxlib`, you can run +GPU with all dependencies +------------------------- -.. code-block:: bash - - pip install --upgrade "jax[cpu]" - -If you want to install JAX with both CPU and NVidia GPU support, you must first install -`CUDA`_ and `CuDNN`_, if they have already been installed. Next, run +BrainPy supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. +To install a GPU-only version of BrainPy, you can run .. 
code-block:: bash - # CUDA 12 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - - # CUDA 11 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -In the event of a version mismatch error with JAX, such as encountering an error message like: - -.. code-block:: text + pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 + pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 - CUDA backend failed to initialize: Found CUDA version 12000, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) -You will need to employ an alternative installation method that aligns with your environment's CUDA version. This can be achieved using the following commands: -.. code-block:: bash +``brainpylib`` +-------------- - # CUDA 12 installation - pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # CUDA 11 installation - pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks. -Alternatively, you can download the preferred release ".whl" file for jaxlib -from the above release links, and install it via ``pip``: +To install the ``brainpylib`` package on CPU devices, you can run .. code-block:: bash - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -.. note:: - - Note that the versions of jaxlib and jax should be consistent. 
- - For example, if you are using jax==0.4.15, you would better install jax==0.4.15. - + pip install brainpylib -MacOS -^^^^^ -If you are using macOS Intel, we recommend you first to install the Miniconda Intel installer: +To install the ``brainpylib`` package on CUDA 11, you can run -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg -2. Then click the downloaded package and install it. - - -If you are using the latest M1 macOS version, you'd better to install the Miniconda M1 installer: - - -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg -2. Then click the downloaded package and install it. - - -Finally, you can install `jax` and `jaxlib` as the same as the Linux platform. .. code-block:: bash - pip install --upgrade "jax[cpu]" - - - -Windows -^^^^^^^ - -For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed -directly from the PyPi channel. - -.. code-block:: bash + pip install brainpylib-cu11x - pip install jax jaxlib +To install the ``brainpylib`` package on CUDA 12, you can run -For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed -from the community supports. Specifically, you can install `jax` and `jaxlib` through: .. code-block:: bash - pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html - -If you are using GPU, you can install GPU-versioned wheels through: - -.. code-block:: bash - - pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html - -Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by -downloading binary releases of JAX for Windows from -https://whls.blob.core.windows.net/unstable/index.html . -Then install it via ``pip``: - -.. 
code-block:: bash - - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -WSL -^^^ - -Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_. -The installation guide can be found in -`WSL Installation Guide for Windows 10/11 `_. -Then, you can install JAX in WSL just like the installation step in Linux/MacOs. - - -Dependency 3: brainpylib ------------------------- - -Many customized operators in BrainPy are implemented in ``brainpylib``. -``brainpylib`` can also be installed from pypi according to your devices. -For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators. -You can install CPU-version `brainpylib` by: - -.. code-block:: bash - - # CPU installation - pip install --upgrade brainpylib - -For Nvidia GPU users, ``brainpylib`` only support Linux system and WSL2 subsystem. You can install the CUDA-version by using: - -.. code-block:: bash - - # CUDA 12 installation - pip install --upgrade brainpylib-cu12x - -.. code-block:: bash - - # CUDA 11 installation - pip install --upgrade brainpylib-cu11x - -Dependency 4: taichi ------------------------- -Now BrainPy supports customized operators implemented in `taichi`_. You can install the latest version of `taichi`_ by: - -.. code-block:: bash - - pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly - -.. _taichi: https://www.taichi-lang.org - -And you can try it in the `operator custom with taichi <../tutorial_advanced/operator_custom_with_taichi.html>`_ tutorial page -Attention: customized operators is still in the experimental stage. If you meet any problems, please contact us through the issue page. - -Running BrainPy with docker ------------------------- - -If you want to use BrainPy in docker, you can use the following command to pull the docker image: - -.. code:: bash - - docker pull brainpy/brainpy:latest - -You can then run the docker image by: - -.. 
code:: bash - - docker run -it --platform linux/amd64 brainpy/brainpy:latest - -Please notice that BrainPy docker image is based on the `ubuntu22.04` image, so it only support CPU version of BrainPy. - - -Running BrainPy online with binder ----------------------------------- - -Click on the following link to launch the Binder environment with the -BrainPy repository: - -|image1| - -Wait for the Binder environment to build. This might take a few moments. - -Once the environment is ready, you'll be redirected to a Jupyter -notebook interface within your web browser. - -.. |image1| image:: https://camo.githubusercontent.com/581c077bdbc6ca6899c86d0acc6145ae85e9d80e6f805a1071793dbe48917982/68747470733a2f2f6d7962696e6465722e6f72672f62616467655f6c6f676f2e737667 - :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main - - -.. _NumPy: https://numpy.org/ -.. _Matplotlib: https://matplotlib.org/ -.. _JAX: https://github.com/google/jax -.. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about -.. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html -.. _SymPy: https://github.com/sympy/sympy -.. _Numba: https://numba.pydata.org/ -.. _CUDA: https://developer.nvidia.com/cuda-downloads -.. _CuDNN: https://developer.nvidia.com/CUDNN + pip install brainpylib-cu12x diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index cc3a38575..11bf53115 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,16 +1,7 @@ BDP Toolboxes ================== - - - This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. - - -Differential Equations ------------------------ - - .. toctree:: :maxdepth: 1 @@ -19,34 +10,11 @@ Differential Equations tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations - - -Toolbox for Modeling -------------------- - -.. 
toctree:: - :maxdepth: 1 - tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights - tutorial_toolbox/inputs - - -Toolbox for Training --------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/optimizers + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient + tutorial_toolbox/inputs - -State Resetting, Saving and Loading ------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 57d18332b..7c9a1c876 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,76 +3,11 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. - -Math Foundation ---------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_math/variables - tutorial_math/control_flows - tutorial_math/Numpy_like_Operations.ipynb - tutorial_math/Dedicated_Operators.ipynb - tutorial_math/einops_in_brainpy.ipynb - - -Model Building with Existing Modules ------------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_building/overview_of_dynamic_model - tutorial_building/build_conductance_neurons_v2.ipynb - tutorial_building/phenon_synapse_models.ipynb - tutorial_building/kinetic_synapse_models.ipynb - tutorial_building/build_network_models - - -Model Building by Customizing New Modules ------------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_building/customize_neuron_models - tutorial_building/customize_synapse_models - tutorial_building/how_to_customze_a_synapse.ipynb - - -Model Simulation ----------------- - -.. 
toctree:: - :maxdepth: 1 - - tutorial_simulation/simulation_dsrunner.ipynb - tutorial_simulation/parallel_for_parameter_exploration.ipynb - tutorial_simulation/monitor_per_multiple_steps.ipynb - - -Model Training --------------- - -This tutorial shows how to train a dynamical system from data or task. - -.. toctree:: - :maxdepth: 1 - - tutorial_training/build_training_models.ipynb - tutorial_training/offline_training.ipynb - tutorial_training/online_training.ipynb - tutorial_training/bp_training.ipynb - tutorial_training/esn_introduction.ipynb - - -Model Analysis --------------- - .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - tutorial_analysis/lowdim_analysis - tutorial_analysis/highdim_analysis - tutorial_analysis/decision_making_model + tutorial_math/index + tutorial_building/index + tutorial_simulation/index + tutorial_training/index + tutorial_analysis/index diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 9c7daff55..f98527458 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('delay') + spk = self.delay.at('I') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index fc36845e6..aeaf0c412 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -30,7 +30,7 @@ def train_data(): class RNN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden): super(RNN, self).__init__() - self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True) + self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) self.out = bp.layers.Dense(num_hidden, 1) def update(self, x): @@ -49,7 +49,7 @@ def loss(predictions, targets, l2_reg=2e-4): # define optimizer -lr = 
bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) +lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) opt = bp.optim.Adam(lr=lr, eps=1e-1) # create a trainer diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt new file mode 100644 index 000000000..99361efa9 --- /dev/null +++ b/requirements-dev-raw.txt @@ -0,0 +1,12 @@ +numpy +jax +jaxlib +matplotlib +msgpack +tqdm +pathos + + +# test requirements +pytest +absl-py diff --git a/requirements-dev.txt b/requirements-dev.txt index 0e475e83d..98398ae2d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,4 @@ numpy -numba brainpylib jax jaxlib @@ -7,7 +6,9 @@ matplotlib msgpack tqdm pathos -taichi==1.7.0 +taichi +numba + # test requirements pytest diff --git a/requirements.txt b/requirements.txt index 02fdebe83..ab5665e73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ numpy jax tqdm -numba -taichi==1.7.0 diff --git a/setup.py b/setup.py index d7fd45e38..885bbf57b 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,8 @@ author='BrainPy Team', author_email='chao.brain@qq.com', packages=packages, - python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba', 'taichi==1.7.0'], + python_requires='>=3.9', + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", @@ -68,11 +68,12 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], - 'cuda': ['jax[cuda]', 'brainpylib-cu12x'], - 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], - 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], - 'tpu': ['jax[tpu]'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jaxlib[cuda12_pip]', 
'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'tpu': ['jaxlib[tpu]', 'numba',], + 'cpu_mini': ['jaxlib>=0.4.13'], + 'cuda_mini': ['jaxlib[cuda12_pip]'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' @@ -89,6 +90,7 @@ 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Topic :: Scientific/Engineering :: Bio-Informatics', From 8460a1d3ff09fe6e9a776ad65398c8a2799ee496 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:17:30 +0800 Subject: [PATCH 7/8] `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` (#639) * `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` * add doc * upgrade * upgrade * try to fix * update CI --- .github/workflows/CI.yml | 192 ++--- .../synapses/tests/test_abstract_synapses.py | 254 +++--- brainpy/_src/math/environment.py | 19 +- brainpy/_src/math/jitconn/_event_matvec.py | 742 +++++++++++++++++- brainpy/_src/math/object_transform/naming.py | 7 +- 5 files changed, 927 insertions(+), 287 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95bd8eafd..d29b07ebc 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,6 +11,12 @@ on: branches: - '**' # matches every branch + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + + #on: # push: # branches: [ master ] @@ -27,6 +33,10 @@ jobs: python-version: [ "3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ 
-35,16 +45,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | cd brainpy @@ -82,40 +85,6 @@ jobs: pytest _src/ -# test_linux_py37: -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - test_macos: runs-on: macos-latest strategy: @@ -124,6 +93,10 @@ jobs: python-version: ["3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -132,16 +105,40 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 + - name: Test with pytest run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + cd brainpy + pytest -n auto --tb=short _src/ + + + test_windows: + strategy: + fail-fast: false + matrix: + os: [ win-2019-16core ] + arch: [ AMD64 ] + python-version: ["3.9", "3.10", "3.11"] + runs-on: ${{ matrix.os }} + + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install - name: Test with pytest run: | cd brainpy @@ -178,104 +175,3 @@ jobs: cd brainpy pytest _src/ -# test_macos_py37: -# runs-on: macos-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: [ "3.7" ] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. 
The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - - -# test_windows: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.9", "3.10", "3.11"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install -r requirements-dev.txt -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd brainpy -# pytest _src/ - - -# test_windows_py37: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz -# python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib -# pip uninstall brainpy 
-y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index c3936f685..6db945ff2 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,130 +1,124 @@ -# -*- coding: utf-8 -*- - -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) 
- - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + + +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, monitors=['pre.V', 'syn.g', 'post.V'], inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) 
+ + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() \ No newline at end of file diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 668f837c0..7827dfed3 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -2,6 +2,7 @@ import functools +import gc import inspect import os import re @@ -16,6 +17,7 @@ from . import modes from . import scales from . 
import defaults +from .object_transform import naming from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) @@ -681,7 +683,9 @@ def set_host_device_count(n): def clear_buffer_memory( platform: str = None, array: bool = True, - compilation: bool = False + transform: bool = True, + compilation: bool = False, + object_name: bool = False, ): """Clear all on-device buffers. @@ -698,9 +702,13 @@ def clear_buffer_memory( platform: str The device to clear its memory. array: bool - Clear all buffer array. + Clear all buffer array. Default is True. compilation: bool - Clear compilation cache. + Clear compilation cache. Default is False. + transform: bool + Clear transform cache. Default is True. + object_name: bool + Clear name cache. Default is True. """ if array: @@ -708,6 +716,11 @@ def clear_buffer_memory( buf.delete() if compilation: jax.clear_caches() + if transform: + naming.clear_stack_cache() + if object_name: + naming.clear_name_cache() + gc.collect() def disable_gpu_memory_preallocation(release_memory: bool = True): diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 976b72b96..ac62bbfaf 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -45,9 +45,747 @@ def event_mv_prob_homo( if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( + events: 
jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +### BRAINPYLIB ### + +def event_mv_prob_homo_brainpylib( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + weight = jnp.atleast_1d(jnp.asarray(weight)) + conn_prob = jnp.atleast_1d(jnp.asarray(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + return r + + +event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform_brainpylib( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + 
+event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal_brainpylib( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ + + +def _event_matvec_prob_homo_abstract( + events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' 
+ assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_homo_cpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_homo' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + weight, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + 
shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_homo_gpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_homo_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1], ) + + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, weight, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_homo_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, weight, clen, seed = primals + event_dot, weight_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(weight_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + if type(weight_dot) is ad.Zero: + if type(event_dot) is ad.Zero: + raise ValueError + dr = mv_prob_homo_p.bind(event_dot, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + elif type(event_dot) is ad.Zero: + dr = mv_prob_homo_p.bind(events, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + 
else: + dr = mv_prob_homo_p.bind(event_dot, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, dr + + +def _event_matvec_prob_homo_transpose( + ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(weight) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_homo_p.bind(ct[0], + weight, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, weight, clen, seed + + +event_mv_prob_homo_p = Primitive('event_mv_prob_homo') +event_mv_prob_homo_p.multiple_results = True +event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) +event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation +ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp +ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose +register_general_batching(event_mv_prob_homo_p) + + +def _event_matvec_prob_uniform_abstract( + events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_low_dtype = _get_dtype(w_low) + _w_high_dtype = _get_dtype(w_low) + assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' + assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' + assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' 
+ assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if w_low.ndim != 1: + raise ValueError('w_low must be a 1D scalar.') + if w_high.ndim != 1: + raise ValueError('w_high must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen must be a 1D scalar.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + assert w_low.dtype == w_high.dtype + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_uniform_cpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_low, + w_high, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed), + 
xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_uniform_gpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_low, w_high, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed),), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_uniform_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, w_low, w_high, clen, seed = primals + events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + outdim_parallel=outdim_parallel, + transpose=transpose) + assert type(w_low_dot) is ad.Zero + assert type(w_high_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_uniform_p.bind(events_dot, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, 
r_dot + + +def _event_matvec_prob_uniform_transpose( + ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_low) is not ad.UndefinedPrimal + assert type(w_high) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_uniform_p.bind(ct[0], + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_low, w_high, clen, seed + + +event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') +event_mv_prob_uniform_p.multiple_results = True +event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) +event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation +register_general_batching(event_mv_prob_uniform_p) +ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp +ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose + + +def _event_matvec_prob_normal_abstract( + events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_mu_dtype = _get_dtype(w_mu) + _w_sigma_dtype = _get_dtype(w_sigma) + assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' + assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' 
+ assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if w_mu.ndim != 1: + raise ValueError('w_mu should be a 1D scalar.') + if w_sigma.ndim != 1: + raise ValueError('w_sigma should be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen should be a 1D scalar.') + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + assert w_mu.dtype == w_sigma.dtype + + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _get_types(event_shape): + event_type = event_shape.element_type() + if event_type == jnp.bool_: + event_type = b'_bool' + out_dtype = dtypes.canonicalize_dtype(float) + elif event_type == jnp.float32: + event_type = b'_float' + out_dtype = event_shape.element_type() + elif event_type == jnp.float64: + event_type = b'_double' + out_dtype = event_shape.element_type() + else: + raise TypeError + + if out_dtype == jnp.float32: + type_name = b'_float' + elif out_dtype == jnp.float64: + type_name = b'_double' + else: + raise TypeError + + return out_dtype, event_type, type_name + + +def _event_matvec_prob_normal_cpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else 
shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_normal' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_mu, + w_sigma, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_normal_gpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_normal_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_mu, w_sigma, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_normal_jvp( + primals, tangents, *, shape, transpose, 
outdim_parallel +): + events, w_mu, w_sigma, clen, seed = primals + events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(w_mu_dot) is ad.Zero + assert type(w_sigma_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_normal_p.bind(events_dot, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, r_dot + + +def _event_matvec_prob_normal_transpose( + ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_mu) is not ad.UndefinedPrimal + assert type(w_sigma) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_normal_p.bind(ct[0], + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_mu, w_sigma, clen, seed + + +event_mv_prob_normal_p = Primitive('event_mv_prob_normal') +event_mv_prob_normal_p.multiple_results = True +event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) +event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation +register_general_batching(event_mv_prob_normal_p) +ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp +ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose + + +### TAICHI ### + +def event_mv_prob_homo_taichi( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = 
None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator supports ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:`M^T` is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`.
+ """ events = as_jax(events) - if isinstance(weight, float): weight = as_jax(weight) - weight = jnp.atleast_1d(as_jax(weight)) + weight = as_jax(weight) + if jnp.ndim(weight) < 1: + weight = jnp.expand_dims(weight, axis=0) conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) if seed is None: diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1c8ca6ef9..6326929c4 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import gc import warnings from brainpy import errors @@ -11,6 +11,7 @@ _name2id = dict() _typed_names = {} +_fun2stack = dict() def check_name_uniqueness(name, obj): @@ -49,9 +50,6 @@ def clear_name_cache(ignore_warn=False): warnings.warn(f'All named models and their ids are cleared.', UserWarning) -_fun2stack = dict() - - def cache_stack(func, stack): _fun2stack[func] = stack @@ -59,6 +57,7 @@ def cache_stack(func, stack): def clear_stack_cache(): for k in tuple(_fun2stack.keys()): del _fun2stack[k] + gc.collect() def get_stack_cache(func): From 3826c548939516015ff138e37566052c5472ccba Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:19:00 +0800 Subject: [PATCH 8/8] add `brainpy.math.surrogate..Surrogate` (#638) * add `brainpy.math.Surrogate` * fix * Update _event_matvec.py * Update _event_matvec.py * Update _event_matvec.py --------- Co-authored-by: He Sichao <1310722434@qq.com> --- brainpy/_src/math/surrogate/_one_input_new.py | 27 +++++++++++++++++-- brainpy/math/surrogate.py | 3 ++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index 64c7280d0..bfffd88f5 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -90,7 +90,30 @@ def _as_jax(x): class 
Surrogate(object): - """The base surrograte gradient function.""" + """The base surrogate gradient function. + + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. + + Examples + -------- + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp + + >>> class MySurrogate(bm.Surrogate): + ... def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha + + """ def __call__(self, x): x = _as_jax(x) @@ -123,7 +146,7 @@ def __init__(self, alpha: float = 4.): self.alpha = alpha def surrogate_fun(self, x): - return sci.special.expit(x) + return sci.special.expit(self.alpha * x) def surrogate_grad(self, x): sgax = sci.special.expit(x * self.alpha) diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py index 0121bddec..bf7897435 100644 --- a/brainpy/math/surrogate.py +++ b/brainpy/math/surrogate.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - from brainpy._src.math.surrogate._one_input_new import ( + Surrogate, + Sigmoid, sigmoid as sigmoid,