From 64d8f54c34b8cdacd32274ca74aca7bfb639e0f6 Mon Sep 17 00:00:00 2001 From: yunhui <38786521+CloudyDory@users.noreply.github.com> Date: Thu, 22 Feb 2024 22:56:23 +0800 Subject: [PATCH 01/21] Fix delayvar not correct in concat mode (#632) --- brainpy/_src/math/delayvars.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 390e04dd..676e4286 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -473,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[1:]], axis=0) + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) else: self.data[:] = value From a579e411cd857ba1230aa7f816e7e40a07c49fd2 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Fri, 1 Mar 2024 10:39:33 +0800 Subject: [PATCH 02/21] [dependency] remove hard dependency of `taichi` and `numba` (#635) * try to remove hard dependency with taichi and numba * [math] Update operator selection strategy for csr matvec * [math] Remove old test case of event csr matvec and csr matvec * [dependency] remove all numba and taichi dependency * fix * Update * Update test_taichi_clean_cache.py * Update CI and remove taichi, numba from requirements * Resolve conflicts * Revert "Merge branch 'master' into dependency-optimize" This reverts commit c54c82214f6f463d9cf0cb9b53d469fe9521edd1, reversing changes made to 76202d5baa427d5698f566a1260a66d8e4f51c8e. * upgrade dependency * fix * update * update * update doc and dependency * update dependency --------- Co-authored-by: Chaoming Wang --- .github/workflows/CI.yml | 62 + README.md | 24 +- brainpy/_src/connect/random_conn.py | 2617 ++++++++------- brainpy/_src/dependency_check.py | 222 +- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/linear.py | 405 +-- brainpy/_src/dnn/tests/test_activation.py | 3 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_linear.py | 441 +-- brainpy/_src/dnn/tests/test_mode.py | 1608 +++++----- brainpy/_src/dnn/tests/test_normalization.py | 5 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- .../_src/dyn/projections/tests/test_STDP.py | 247 +- .../_src/dyn/projections/tests/test_aligns.py | 883 +++--- .../synapses/tests/test_abstract_synapses.py | 256 +- .../tests/test_biological_synapses.py | 211 +- brainpy/_src/math/defaults.py | 18 +- brainpy/_src/math/delayvars.py | 5 +- brainpy/_src/math/environment.py | 23 +- brainpy/_src/math/event/__init__.py | 2 - brainpy/_src/math/event/_csr_matvec.py | 1271 ++------ brainpy/_src/math/event/_info_collection.py | 198 -- .../tests/event_info_VS_jax_operators.py | 275 -- .../_src/math/event/tests/test_event_csrmv.py | 7 + .../math/event/tests/test_event_csrmv_old.py | 324 -- brainpy/_src/math/event/tests/test_info.py | 62 - .../_src/math/event/tests/test_info_gpu.py | 14 - brainpy/_src/math/index_tricks.py | 305 -- brainpy/_src/math/jitconn/__init__.py | 5 +- brainpy/_src/math/jitconn/_event_matvec.py | 2821 ++++++----------- brainpy/_src/math/jitconn/_matvec.py | 2175 ++++--------- .../math/jitconn/tests/test_event_matvec.py | 6 + .../_src/math/jitconn/tests/test_matvec.py | 5 + .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 4 +- .../_src/math/object_transform/controls.py | 136 +- brainpy/_src/math/object_transform/jit.py | 69 +- brainpy/_src/math/object_transform/naming.py | 3 +- .../_src/math/object_transform/parallels.py | 460 +++ brainpy/_src/math/object_transform/tools.py | 75 +- .../_src/math/object_transform/variables.py | 45 +- brainpy/_src/math/op_register/__init__.py | 15 +- brainpy/_src/math/op_register/base.py | 30 +- .../op_register/numba_approach/__init__.py | 14 +- .../numba_approach/cpu_translation.py | 298 +- brainpy/_src/math/op_register/numba_based.py | 13 +- .../math/op_register/tests/test_ad_support.py | 7 +- .../op_register/tests/test_numba_based.py | 7 +- .../op_register/tests/test_taichi_based.py | 7 +- .../tests/test_taichi_clean_cache.py | 110 +- brainpy/_src/math/sparse/__init__.py | 5 +- brainpy/_src/math/sparse/_bsr_mm.py | 100 +- brainpy/_src/math/sparse/_csr_mv.py | 896 ++---- brainpy/_src/math/sparse/_utils.py | 3 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 6 +- .../_src/math/sparse/tests/test_csrmv_old.py | 352 -- brainpy/_src/math/tests/test_tifunc.py | 246 +- brainpy/_src/math/tifunc.py | 513 ++- brainpy/_src/tests/test_dyn_runner.py | 267 +- brainpy/_src/tools/functions.py | 192 -- brainpy/_src/tools/progress.py | 519 +++ brainpy/_src/tools/tests/test_functions.py | 24 - brainpy/errors.py | 11 +- brainpy/math/__init__.py | 205 +- brainpy/math/compat_pytorch.py | 2 +- brainpy/math/event.py | 2 - brainpy/math/jitconn.py | 20 +- brainpy/math/oo_transform.py | 4 - brainpy/math/op_register.py | 26 +- brainpy/math/sparse.py | 7 +- brainpy/math/tifunc.py | 51 +- brainpy/tools.py | 5 - docs/advanced_tutorials.rst | 51 +- docs/apis/brainpy.math.oo_transform.rst | 1 - docs/quickstart/installation.rst | 262 +- docs/toolboxes.rst | 38 +- docs/tutorials.rst | 77 +- examples/dynamics_simulation/ei_nets.py | 2 +- examples/dynamics_training/integrator_rnn.py | 4 +- requirements-dev-raw.txt | 12 + requirements-dev.txt | 5 +- requirements.txt | 2 - setup.py | 16 +- 84 files changed, 7956 insertions(+), 11838 deletions(-) delete mode 100644 brainpy/_src/math/event/_info_collection.py delete mode 100644 brainpy/_src/math/event/tests/event_info_VS_jax_operators.py delete mode 100644 brainpy/_src/math/event/tests/test_event_csrmv_old.py delete mode 100644 brainpy/_src/math/event/tests/test_info.py delete mode 100644 brainpy/_src/math/event/tests/test_info_gpu.py delete mode 100644 brainpy/_src/math/index_tricks.py create mode 100644 brainpy/_src/math/object_transform/parallels.py delete mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_old.py delete mode 100644 brainpy/_src/tools/functions.py create mode 100644 brainpy/_src/tools/progress.py delete mode 100644 brainpy/_src/tools/tests/test_functions.py create mode 100644 requirements-dev-raw.txt diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 84aa028e..95bd8eaf 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -50,6 +50,37 @@ jobs: cd brainpy pytest _src/ + test_linux_with_taichi_numba: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ "3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_linux_py37: # runs-on: ubuntu-latest @@ -116,6 +147,37 @@ jobs: cd brainpy pytest _src/ + test_macos_with_taichi_numba: + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest taichi numba + if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + pip uninstall brainpy -y + python setup.py install + - name: Lint with flake8 + run: | + # stop the build if there are Python syntax errors or undefined names + flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics + # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide + flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + - name: Test with pytest + run: | + cd brainpy + pytest _src/ + # test_macos_py37: # runs-on: macos-latest # strategy: diff --git a/README.md b/README.md index 6d2ee4bf..a7fe0b72 100644 --- a/README.md +++ b/README.md @@ -25,29 +25,7 @@ BrainPy is a flexible, efficient, and extensible framework for computational neu ## Installation -BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. Install the latest version of BrainPy: - -```bash -$ pip install brainpy -U -``` - -In addition, many customized operators in BrainPy are implemented in ``brainpylib``. -Install the latest version of `brainpylib` by: - -```bash -# CPU installation for Linux, macOS and Windows -$ pip install --upgrade brainpylib -``` - -```bash -# CUDA 12 installation for Linux only -$ pip install --upgrade brainpylib-cu12x -``` - -```bash -# CUDA 11 installation for Linux only -$ pip install --upgrade brainpylib-cu11x -``` +BrainPy is based on Python (>=3.8) and can be installed on Linux (Ubuntu 16.04 or later), macOS (10.12 or later), and Windows platforms. For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html) diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py index 1f5b1db6..0e4ee769 100644 --- a/brainpy/_src/connect/random_conn.py +++ b/brainpy/_src/connect/random_conn.py @@ -1,1372 +1,1245 @@ -# -*- coding: utf-8 -*- -from functools import partial -from typing import Optional - -from jax import vmap, jit, numpy as jnp -import numpy as np -from numba import njit - -import brainpy.math as bm -from brainpy.errors import ConnectorError -from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed -from brainpy._src.tools.package import SUPPORT_NUMBA -from .base import * - -__all__ = [ - 'FixedProb', - 'FixedPreNum', - 'FixedPostNum', - 'FixedTotalNum', - 'GaussianProb', - 'ProbDist', - - 'SmallWorld', - 'ScaleFreeBA', - 'ScaleFreeBADual', - 'PowerLaw', -] - - -class FixedProb(TwoEndConnector): - """Connect the post-synaptic neurons with fixed probability. - - Parameters - ---------- - prob: float - The conn probability. - pre_ratio: float - The ratio of pre-synaptic neurons to connect. - include_self : bool - Whether create (i, i) conn? - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - seed : optional, int - Seed the random generator. - """ - - def __init__(self, - prob, - pre_ratio=1., - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedProb, self).__init__(**kwargs) - assert 0. <= prob <= 1. - assert 0. <= pre_ratio <= 1. - self.prob = prob - self.pre_ratio = pre_ratio - self.include_self = include_self - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self._jaxrand = bm.random.default_rng(self.seed) - self._nprand = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' - f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' - f'seed={self.seed})') - - def _iii(self): - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - - if self.pre_ratio < 1.: - pre_num_to_select = int(self.pre_num * self.pre_ratio) - pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) - else: - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - - post_num_total = self.post_num - post_num_to_select = int(self.post_num * self.prob) - - if self.allow_multi_conn: - selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self._nprand.randint(0, int(1e8))) - else: - rng = self._nprand - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._iii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - - def build_mat(self): - if self.pre_ratio < 1.: - pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state - else: - mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) - mat = bm.asarray(mat) - if not self.include_self: - bm.fill_diagonal(mat, False) - return mat.astype(MAT_DTYPE) - - -class FixedTotalNum(TwoEndConnector): - """Connect the synaptic neurons with fixed total number. - - Parameters - ---------- - num : float,int - The conn total number. - allow_multi_conn : bool, optional - Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. - seed: int, optional - The random number seed. - """ - - def __init__(self, - num, - allow_multi_conn=False, - seed=None, **kwargs): - super().__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) - - def build_coo(self): - mat_element_num = self.pre_num * self.post_num - if self.num > mat_element_num: - raise ConnectorError(f'"num" must be smaller than "all2all num", ' - f'but got {self.num} > {mat_element_num}') - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) - selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) - else: - index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) - selected_pre_ids = index // self.post_num - selected_post_ids = index % self.post_num - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' - - -class FixedNum(TwoEndConnector): - def __init__(self, - num, - include_self=True, - allow_multi_conn=False, - seed=None, - **kwargs): - super(FixedNum, self).__init__(**kwargs) - if isinstance(num, int): - assert num >= 0, '"num" must be a non-negative integer.' - elif isinstance(num, float): - assert 0. <= num <= 1., '"num" must be in [0., 1.).' - else: - raise ConnectorError(f'Unknown type: {type(num)}') - self.num = num - self.seed = format_seed(seed) - self.include_self = include_self - self.allow_multi_conn = allow_multi_conn - self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) - - def __repr__(self): - return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' - - -class FixedPreNum(FixedNum): - """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def build_coo(self): - if isinstance(self.num, int) and self.num > self.pre_num: - raise ConnectorError(f'"num" must be smaller than "pre_num", ' - f'but got {self.num} > {self.pre_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num - pre_num_total = self.pre_num - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(post_num_total): - posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) - return posts - - selected_pre_ids = jnp.asarray(single_conn()) - - post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select - if not self.include_self: - true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) - post_nums -= jnp.sum(true_ids, axis=1) - selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_pre_ids = selected_pre_ids.flatten() - selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - -class FixedPostNum(FixedNum): - """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. - - Parameters - ---------- - num : float, int - The conn probability (if "num" is float) or the fixed number of - connectivity (if "num" is int). - include_self : bool - Whether create (i, i) conn ? - seed : None, int - Seed the random generator. - allow_multi_conn: bool - Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? - - .. versionadded:: 2.2.3.2 - - """ - - def _ii(self): - if isinstance(self.num, int) and self.num > self.post_num: - raise ConnectorError(f'"num" must be smaller than "post_num", ' - f'but got {self.num} > {self.post_num}') - if (not self.include_self) and (self.pre_num != self.post_num): - raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' - f'But `include_self` is set to True.') - post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num - pre_num_to_select = self.pre_num - pre_ids = jnp.arange(self.pre_num) - post_num_total = self.post_num - - if self.allow_multi_conn: - selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) - - else: - if SUPPORT_NUMBA: - rng = np.random - numba_seed(self.rng.randint(0, int(1e8))) - else: - rng = self.rng - - @numba_jit # (parallel=True, nogil=True) - def single_conn(): - posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) - for i in numba_range(pre_num_to_select): - posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) - return posts - - selected_post_ids = jnp.asarray(single_conn()) - return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) - - def build_coo(self): - _, post_num_to_select, selected_post_ids, pre_ids = self._ii() - selected_post_ids = selected_post_ids.flatten() - selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) - if not self.include_self: - true_ids = selected_pre_ids != selected_post_ids - selected_pre_ids = selected_pre_ids[true_ids] - selected_post_ids = selected_post_ids[true_ids] - return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) - - def build_csr(self): - pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() - pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select - if not self.include_self: - true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) - pre_nums -= jnp.sum(true_ids, axis=1) - selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] - else: - selected_post_ids = selected_post_ids.flatten() - selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) - return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) - -@jit -@partial(vmap, in_axes=(0, None, None)) -def gaussian_prob_dist_cal1(i_value, post_values, sigma): - dists = jnp.abs(i_value - post_values) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - -@jit -@partial(vmap, in_axes=(0, None, None, None)) -def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): - dists = jnp.abs(i_value - post_values) - dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) - exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) - return bm.asarray(exp_dists) - - -class GaussianProb(OneEndConnector): - r"""Builds a Gaussian connectivity pattern within a population of neurons, - where the connection probability decay according to the gaussian function. - - Specifically, for any pair of neurons :math:`(i, j)`, - - .. math:: - - p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) - - where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. - - Parameters - ---------- - sigma : float - Width of the Gaussian function. - encoding_values : optional, list, tuple, int, float - The value ranges to encode for neurons at each axis. - - - If `values` is not provided, the neuron only encodes each positional - information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is - the index in the high-dimensional space. - - If `values` is a single tuple/list of int/float, neurons at each dimension - will encode the same range of values. For example, ``values=(0, np.pi)``, - neurons at each dimension will encode a continuous value space ``[0, np.pi]``. - - If `values` is a tuple/list of list/tuple, it means the value space will be - different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. - - periodic_boundary : bool - Whether the neuron encode the value space with the periodic boundary. - normalize : bool - Whether normalize the connection probability . - include_self : bool - Whether create the connection at the same position. - seed : int - The random seed. - """ - - def __init__( - self, - sigma: float, - encoding_values: Optional[np.ndarray] = None, - normalize: bool = True, - include_self: bool = True, - periodic_boundary: bool = False, - seed: int = None, - **kwargs - ): - super(GaussianProb, self).__init__(**kwargs) - self.sigma = sigma - self.encoding_values = encoding_values - self.normalize = normalize - self.include_self = include_self - self.periodic_boundary = periodic_boundary - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - - def __repr__(self): - return (f'{self.__class__.__name__}(sigma={self.sigma}, ' - f'normalize={self.normalize}, ' - f'periodic_boundary={self.periodic_boundary}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - self.rng = np.random.RandomState(self.seed) - # value range to encode - if self.encoding_values is None: - value_ranges = tuple([(0, s) for s in self.pre_size]) - elif isinstance(self.encoding_values, (tuple, list)): - if len(self.encoding_values) == 0: - raise ConnectorError(f'encoding_values has a length of 0.') - elif isinstance(self.encoding_values[0], (int, float)): - assert len(self.encoding_values) == 2 - assert self.encoding_values[0] < self.encoding_values[1] - value_ranges = tuple([self.encoding_values for _ in self.pre_size]) - elif isinstance(self.encoding_values[0], (tuple, list)): - if len(self.encoding_values) != len(self.pre_size): - raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' - f'the encoded values provided only has {len(self.encoding_values)}-D. ' - f'Error in {str(self)}.') - for v in self.encoding_values: - assert isinstance(v[0], (int, float)) - assert len(v) == 2 - value_ranges = tuple(self.encoding_values) - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - else: - raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') - - # values - values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] - # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) - post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) - value_sizes = np.array([v[1] - v[0] for v in value_ranges]) - if value_sizes.ndim < post_values.ndim: - value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - - # probability of connections - if isOptimized: - i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) - for i in range(self.pre_num): - list_index = i - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - i_value_list[list_index] = i_value - - if self.periodic_boundary: - prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) - else: - prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) - else: - prob_mat = [] - for i in range(self.pre_num): - # values for node i - i_coordinate = tuple() - for s in self.pre_size[:-1]: - i, pos = divmod(i, s) - i_coordinate += (pos,) - i_coordinate += (i,) - i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) - if i_value.ndim < post_values.ndim: - i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) - # distances - dists = np.abs(i_value - post_values) - if self.periodic_boundary: - dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) - exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) - prob_mat.append(exp_dists) - prob_mat = np.stack(prob_mat) - - if self.normalize: - prob_mat /= prob_mat.max() - - # connectivity - conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) - if not self.include_self: - np.fill_diagonal(conn_mat, False) - return conn_mat - - -class SmallWorld(TwoEndConnector): - """Build a Watts–Strogatz small-world graph. - - Parameters - ---------- - num_neighbor : int - Each node is joined with its `k` nearest neighbors in a ring - topology. - prob : float - The probability of rewiring each edge - directed : bool - Whether the graph is a directed graph. - include_self : bool - Whether include the node self. - - Notes - ----- - First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is - joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors - if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as - follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with - :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new - edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. - - References - ---------- - .. [1] Duncan J. Watts and Steven H. Strogatz, - Collective dynamics of small-world networks, - Nature, 393, pp. 440--442, 1998. - """ - - def __init__( - self, - num_neighbor, - prob, - directed=False, - include_self=False, - seed=None, - **kwargs - ): - super(SmallWorld, self).__init__(**kwargs) - self.prob = prob - self.directed = directed - self.num_neighbor = num_neighbor - self.include_self = include_self - - self.seed = format_seed(seed) - self.rng = np.random.RandomState(seed=self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _smallworld_rewire(i, all_j): - if rng.random(1) < prob: - non_connected = np.where(np.logical_not(all_j))[0] - if len(non_connected) <= 1: - return -1 - # Enforce no self-loops or multiple edges - w = rng.choice(non_connected) - while (not include_self) and w == i: - # non_connected.remove(w) - w = rng.choice(non_connected) - return w - else: - return -1 - - self._connect = numba_jit(_smallworld_rewire) - - def __repr__(self): - return (f'{self.__class__.__name__}(prob={self.prob}, ' - f'directed={self.directed}, ' - f'num_neighbor={self.num_neighbor}, ' - f'include_self={self.include_self}, ' - f'seed={self.seed})') - - def build_conn(self): - assert self.pre_size == self.post_size - - # seed - self.seed = self.rng.randint(1, int(1e7)) - numba_seed(self.seed) - - if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): - num_node = self.pre_num - - if self.num_neighbor > num_node: - raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") - # If k == n, the graph is complete not Watts-Strogatz - if self.num_neighbor == num_node: - conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) - else: - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 - # connect each node to k/2 neighbors - for j in range(1, self.num_neighbor // 2 + 1): - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - conn[nodes, targets] = True - conn[targets, nodes] = True - - # rewire edges from each node - # loop over all nodes in order (label) and neighbors in order (distance) - # no self loops or multiple edges allowed - for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors - targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list - if self.directed: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(prob=self.prob, i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[u, w] = True - w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) - if w != -1: - conn[v, u] = False - conn[w, u] = True - else: - # inner loop in node order - for u, v in zip(nodes, targets): - w = self._connect(i=u, all_j=conn[u]) - if w != -1: - conn[u, v] = False - conn[v, u] = False - conn[u, w] = True - conn[w, u] = True - # conn = np.asarray(conn, dtype=MAT_DTYPE) - else: - raise ConnectorError('Currently only support 1D ring connection.') - - return 'mat', conn - - -# def _random_subset(seq, m, rng): -# """Return m unique elements from seq. -# -# This differs from random.sample which can return repeated -# elements if seq holds repeated elements. -# -# Note: rng is a random.Random or numpy.random.RandomState instance. -# """ -# targets = set() -# while len(targets) < m: -# x = rng.choice(seq) -# targets.add(x) -# return targets - - -class ScaleFreeBA(TwoEndConnector): - """Build a random graph according to the Barabási–Albert preferential - attachment model. - - A graph of :math:`num\_node` nodes is grown by attaching new nodes each with - :math:`m` edges that are preferentially attached to existing nodes - with high degree. - - Parameters - ---------- - m : int - Number of edges to attach from a new node to existing nodes - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m` does not satisfy ``1 <= m < n``. - - References - ---------- - .. [1] A. L. Barabási and R. Albert "Emergence of scaling in - random networks", Science 286, pp 509-512, 1999. - """ - - def __init__(self, m, directed=False, seed=None, **kwargs): - super(ScaleFreeBA, self).__init__(**kwargs) - self.m = m - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, ' - f'directed={self.directed}, ' - f'seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m < 1 or self.m >= num_node: - raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " - f"m < n, while m = {self.m} and n = {num_node}") - - # Add m initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - # Target nodes for new edges - targets = list(range(self.m)) - # List of existing nodes, with nodes repeated once for each adjacent edge - - if not isOptimized: - repeated_nodes = [] - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * self.m) - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), self.m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) - size_repeated_nodes = 0 - # Start adding the other n-m nodes. The first node is m. - source = self.m - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * self.m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets - size_repeated_nodes += self.m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source - size_repeated_nodes += self.m - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) - source += 1 - - return conn - - -class ScaleFreeBADual(TwoEndConnector): - r"""Build a random graph according to the dual Barabási–Albert preferential - attachment model. - - A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ - edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that - are preferentially attached to existing nodes with high degree. - - Parameters - ---------- - m1 : int - Number of edges to attach from a new node to existing nodes with probability :math:`p` - m2 : int - Number of edges to attach from a new node to existing nodes with probability :math:`1-p` - p : float - The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Raises - ------ - ConnectorError - If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. - - References - ---------- - .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. - """ - - def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): - super(ScaleFreeBADual, self).__init__(**kwargs) - self.m1 = m1 - self.m2 = m2 - self.p = p - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' - f'p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - - num_node = self.pre_num - if self.m1 < 1 or self.m1 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " - f"while m1 = {self.m1} and num_node = {num_node}.") - if self.m2 < 1 or self.m2 >= num_node: - raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " - f"while m2 = {self.m2} and num_node = {num_node}.") - if self.p < 0 or self.p > 1: - raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") - - # Add max(m1,m2) initial nodes (m0 in barabasi-speak) - conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) - - if not isOptimized: - # List of existing nodes, with nodes repeated once for each adjacent edge - repeated_nodes = [] - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes.extend(targets) - # And the new node "source" has m edges to add to the list. - repeated_nodes.extend([source] * m) - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(np.asarray(repeated_nodes), m)) - source += 1 - return conn - - # List of existing nodes, with nodes repeated once for each adjacent edge - # Preallocate repeated_nodes as a numpy array - repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) - size_repeated_nodes = 0 - # Start adding the remaining nodes. - source = max(self.m1, self.m2) - # Pick which m to use first time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Target nodes for new edges - targets = list(range(m)) - while source < num_node: - # Add edges to m nodes from the source. - origins = [source] * m - conn[origins, targets] = True - if not self.directed: - conn[targets, origins] = True - # Add one node to the list for each new edge just created. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets - size_repeated_nodes += m - # And the new node "source" has m edges to add to the list. - repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source - size_repeated_nodes += m - # Pick which m to use next time (m1 or m2) - m = self.m1 if self.rng.random() < self.p else self.m2 - # Now choose m unique nodes from the existing nodes - # Pick uniformly from repeated_nodes (preferential attachment) - targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) - source += 1 - - return conn - - -class PowerLaw(TwoEndConnector): - """Holme and Kim algorithm for growing graphs with powerlaw - degree distribution and approximate average clustering. - - Parameters - ---------- - m : int - the number of random edges to add for each new node - p : float, - Probability of adding a triangle after adding a random edge - seed : integer, random_state, or None (default) - Indicator of random number generation state. - - Notes - ----- - The average clustering has a hard time getting above a certain - cutoff that depends on :math:`m`. This cutoff is often quite low. The - transitivity (fraction of triangles to possible triangles) seems to - decrease with network size. - - It is essentially the Barabási–Albert (BA) growth model with an - extra step that each random edge is followed by a chance of - making an edge to one of its neighbors too (and thus a triangle). - - This algorithm improves on BA in the sense that it enables a - higher average clustering to be attained if desired. - - It seems possible to have a disconnected graph with this algorithm - since the initial :math:`m` nodes may not be all linked to a new node - on the first iteration like the BA model. - - Raises - ------ - ConnectorError - If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not - satisfy :math:`0 <= p <= 1`. - - References - ---------- - .. [1] P. Holme and B. J. Kim, - "Growing scale-free networks with tunable clustering", - Phys. Rev. E, 65, 026107, 2002. - """ - - def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): - super(PowerLaw, self).__init__(**kwargs) - self.m = m - self.p = p - if self.p > 1 or self.p < 0: - raise ConnectorError(f"p must be in [0,1], while p={self.p}") - self.directed = directed - self.seed = format_seed(seed) - self.rng = np.random.RandomState(self.seed) - rng = np.random if SUPPORT_NUMBA else self.rng - - def _random_subset(seq, m): - targets = set() - while len(targets) < m: - x = rng.choice(seq) - targets.add(x) - return targets - - self._connect = numba_jit(_random_subset) - - def __repr__(self): - return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') - - def build_mat(self, isOptimized=True): - assert self.pre_num == self.post_num - # seed - self.rng = np.random.RandomState(self.seed) - numba_seed(self.seed) - num_node = self.pre_num - if self.m < 1 or num_node < self.m: - raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) - size = np.prod(pre_size) - - for i in range(size): - pre_pos = np.asarray([p[i] for p in pre_ids]) - pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) - connected_pres.extend(pres) - connected_posts.extend(posts) - return np.asarray(connected_pres), np.asarray(connected_posts) +# -*- coding: utf-8 -*- + +from functools import partial +from typing import Optional + +from jax import vmap, jit, numpy as jnp +import numpy as np + +import brainpy.math as bm +from brainpy.errors import ConnectorError +from brainpy.tools import numba_seed, numba_jit, numba_range, format_seed +from brainpy._src.tools.package import SUPPORT_NUMBA +from .base import * + + +__all__ = [ + 'FixedProb', + 'FixedPreNum', + 'FixedPostNum', + 'FixedTotalNum', + 'GaussianProb', + 'ProbDist', + + 'SmallWorld', + 'ScaleFreeBA', + 'ScaleFreeBADual', + 'PowerLaw', +] + + +class FixedProb(TwoEndConnector): + """Connect the post-synaptic neurons with fixed probability. + + Parameters + ---------- + prob: float + The conn probability. + pre_ratio: float + The ratio of pre-synaptic neurons to connect. + include_self : bool + Whether create (i, i) conn? + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + seed : optional, int + Seed the random generator. + """ + + def __init__(self, + prob, + pre_ratio=1., + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedProb, self).__init__(**kwargs) + assert 0. <= prob <= 1. + assert 0. <= pre_ratio <= 1. + self.prob = prob + self.pre_ratio = pre_ratio + self.include_self = include_self + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self._jaxrand = bm.random.default_rng(self.seed) + self._nprand = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, pre_ratio={self.pre_ratio}, ' + f'include_self={self.include_self}, allow_multi_conn={self.allow_multi_conn}, ' + f'seed={self.seed})') + + def _iii(self): + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + + if self.pre_ratio < 1.: + pre_num_to_select = int(self.pre_num * self.pre_ratio) + pre_ids = self._jaxrand.choice(self.pre_num, size=(pre_num_to_select,), replace=False) + else: + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + + post_num_total = self.post_num + post_num_to_select = int(self.post_num * self.prob) + + if self.allow_multi_conn: + selected_post_ids = self._jaxrand.randint(0, post_num_total, (pre_num_to_select, post_num_to_select)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self._nprand.randint(0, int(1e8))) + else: + rng = self._nprand + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._iii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._iii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + def build_mat(self): + if self.pre_ratio < 1.: + pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state + else: + mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) + mat = bm.asarray(mat) + if not self.include_self: + bm.fill_diagonal(mat, False) + return mat.astype(MAT_DTYPE) + + +class FixedTotalNum(TwoEndConnector): + """Connect the synaptic neurons with fixed total number. + + Parameters + ---------- + num : float,int + The conn total number. + allow_multi_conn : bool, optional + Whether allow one pre-synaptic neuron connects to multiple post-synaptic neurons. + seed: int, optional + The random number seed. + """ + + def __init__(self, + num, + allow_multi_conn=False, + seed=None, **kwargs): + super().__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' + else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) + + def build_coo(self): + mat_element_num = self.pre_num * self.post_num + if self.num > mat_element_num: + raise ConnectorError(f'"num" must be smaller than "all2all num", ' + f'but got {self.num} > {mat_element_num}') + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, self.pre_num, (self.num,)) + selected_post_ids = self.rng.randint(0, self.post_num, (self.num,)) + else: + index = self.rng.choice(mat_element_num, size=(self.num,), replace=False) + selected_pre_ids = index // self.post_num + selected_post_ids = index % self.post_num + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, seed={self.seed})' + + +class FixedNum(TwoEndConnector): + def __init__(self, + num, + include_self=True, + allow_multi_conn=False, + seed=None, + **kwargs): + super(FixedNum, self).__init__(**kwargs) + if isinstance(num, int): + assert num >= 0, '"num" must be a non-negative integer.' + elif isinstance(num, float): + assert 0. <= num <= 1., '"num" must be in [0., 1.).' + else: + raise ConnectorError(f'Unknown type: {type(num)}') + self.num = num + self.seed = format_seed(seed) + self.include_self = include_self + self.allow_multi_conn = allow_multi_conn + self.rng = bm.random.RandomState(self.seed) if allow_multi_conn else np.random.RandomState(self.seed) + + def __repr__(self): + return f'{self.__class__.__name__}(num={self.num}, include_self={self.include_self}, seed={self.seed})' + + +class FixedPreNum(FixedNum): + """Connect a fixed number pf pre-synaptic neurons for each post-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def build_coo(self): + if isinstance(self.num, int) and self.num > self.pre_num: + raise ConnectorError(f'"num" must be smaller than "pre_num", ' + f'but got {self.num} > {self.pre_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + pre_num_to_select = int(self.pre_num * self.num) if isinstance(self.num, float) else self.num + pre_num_total = self.pre_num + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_pre_ids = self.rng.randint(0, pre_num_total, (post_num_total, pre_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((post_num_total, pre_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(post_num_total): + posts[i] = rng.choice(pre_num_total, pre_num_to_select, replace=False) + return posts + + selected_pre_ids = jnp.asarray(single_conn()) + + post_nums = jnp.ones((post_num_total,), dtype=get_idx_type()) * pre_num_to_select + if not self.include_self: + true_ids = selected_pre_ids == jnp.reshape(jnp.arange(pre_num_total), (-1, 1)) + post_nums -= jnp.sum(true_ids, axis=1) + selected_pre_ids = selected_pre_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_pre_ids = selected_pre_ids.flatten() + selected_post_ids = jnp.repeat(jnp.arange(post_num_total), post_nums) + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + +class FixedPostNum(FixedNum): + """Connect the fixed number of post-synaptic neurons for each pre-synaptic neuron. + + Parameters + ---------- + num : float, int + The conn probability (if "num" is float) or the fixed number of + connectivity (if "num" is int). + include_self : bool + Whether create (i, i) conn ? + seed : None, int + Seed the random generator. + allow_multi_conn: bool + Allow one pre-synaptic neuron connects to multiple post-synaptic neurons? + + .. versionadded:: 2.2.3.2 + + """ + + def _ii(self): + if isinstance(self.num, int) and self.num > self.post_num: + raise ConnectorError(f'"num" must be smaller than "post_num", ' + f'but got {self.num} > {self.post_num}') + if (not self.include_self) and (self.pre_num != self.post_num): + raise ConnectorError(f'We found pre_num != post_num ({self.pre_num} != {self.post_num}). ' + f'But `include_self` is set to True.') + post_num_to_select = int(self.post_num * self.num) if isinstance(self.num, float) else self.num + pre_num_to_select = self.pre_num + pre_ids = jnp.arange(self.pre_num) + post_num_total = self.post_num + + if self.allow_multi_conn: + selected_post_ids = self.rng.randint(0, post_num_total, (pre_num_to_select, post_num_to_select,)) + + else: + if SUPPORT_NUMBA: + rng = np.random + numba_seed(self.rng.randint(0, int(1e8))) + else: + rng = self.rng + + @numba_jit # (parallel=True, nogil=True) + def single_conn(): + posts = np.zeros((pre_num_to_select, post_num_to_select), dtype=IDX_DTYPE) + for i in numba_range(pre_num_to_select): + posts[i] = rng.choice(post_num_total, post_num_to_select, replace=False) + return posts + + selected_post_ids = jnp.asarray(single_conn()) + return pre_num_to_select, post_num_to_select, bm.as_jax(selected_post_ids), bm.as_jax(pre_ids) + + def build_coo(self): + _, post_num_to_select, selected_post_ids, pre_ids = self._ii() + selected_post_ids = selected_post_ids.flatten() + selected_pre_ids = jnp.repeat(pre_ids, post_num_to_select) + if not self.include_self: + true_ids = selected_pre_ids != selected_post_ids + selected_pre_ids = selected_pre_ids[true_ids] + selected_post_ids = selected_post_ids[true_ids] + return selected_pre_ids.astype(get_idx_type()), selected_post_ids.astype(get_idx_type()) + + def build_csr(self): + pre_num_to_select, post_num_to_select, selected_post_ids, pre_ids = self._ii() + pre_nums = jnp.ones(pre_num_to_select) * post_num_to_select + if not self.include_self: + true_ids = selected_post_ids == jnp.reshape(pre_ids, (-1, 1)) + pre_nums -= jnp.sum(true_ids, axis=1) + selected_post_ids = selected_post_ids.flatten()[jnp.logical_not(true_ids).flatten()] + else: + selected_post_ids = selected_post_ids.flatten() + selected_pre_inptr = jnp.cumsum(jnp.concatenate([jnp.zeros(1), pre_nums])) + return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type()) + + +@jit +@partial(vmap, in_axes=(0, None, None)) +def gaussian_prob_dist_cal1(i_value, post_values, sigma): + dists = jnp.abs(i_value - post_values) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +@jit +@partial(vmap, in_axes=(0, None, None, None)) +def gaussian_prob_dist_cal2(i_value, post_values, value_sizes, sigma): + dists = jnp.abs(i_value - post_values) + dists = jnp.where(dists > (value_sizes / 2), value_sizes - dists, dists) + exp_dists = jnp.exp(-(jnp.sqrt(jnp.sum(dists ** 2, axis=0)) / sigma) ** 2 / 2) + return bm.asarray(exp_dists) + + +class GaussianProb(OneEndConnector): + r"""Builds a Gaussian connectivity pattern within a population of neurons, + where the connection probability decay according to the gaussian function. + + Specifically, for any pair of neurons :math:`(i, j)`, + + .. math:: + + p(i, j)=\exp(-\frac{\sum_{k=1}^n |v_k^i - v_k^j|^2 }{2\sigma^2}) + + where :math:`v_k^i` is the :math:`i`-th neuron's encoded value at dimension :math:`k`. + + Parameters + ---------- + sigma : float + Width of the Gaussian function. + encoding_values : optional, list, tuple, int, float + The value ranges to encode for neurons at each axis. + + - If `values` is not provided, the neuron only encodes each positional + information, i.e., :math:`(i, j, k, ...)`, where :math:`i, j, k` is + the index in the high-dimensional space. + - If `values` is a single tuple/list of int/float, neurons at each dimension + will encode the same range of values. For example, ``values=(0, np.pi)``, + neurons at each dimension will encode a continuous value space ``[0, np.pi]``. + - If `values` is a tuple/list of list/tuple, it means the value space will be + different for each dimension. For example, ``values=((-np.pi, np.pi), (10, 20), (0, 2 * np.pi))``. + + periodic_boundary : bool + Whether the neuron encode the value space with the periodic boundary. + normalize : bool + Whether normalize the connection probability . + include_self : bool + Whether create the connection at the same position. + seed : int + The random seed. + """ + + def __init__( + self, + sigma: float, + encoding_values: Optional[np.ndarray] = None, + normalize: bool = True, + include_self: bool = True, + periodic_boundary: bool = False, + seed: int = None, + **kwargs + ): + super(GaussianProb, self).__init__(**kwargs) + self.sigma = sigma + self.encoding_values = encoding_values + self.normalize = normalize + self.include_self = include_self + self.periodic_boundary = periodic_boundary + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + + def __repr__(self): + return (f'{self.__class__.__name__}(sigma={self.sigma}, ' + f'normalize={self.normalize}, ' + f'periodic_boundary={self.periodic_boundary}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + self.rng = np.random.RandomState(self.seed) + # value range to encode + if self.encoding_values is None: + value_ranges = tuple([(0, s) for s in self.pre_size]) + elif isinstance(self.encoding_values, (tuple, list)): + if len(self.encoding_values) == 0: + raise ConnectorError(f'encoding_values has a length of 0.') + elif isinstance(self.encoding_values[0], (int, float)): + assert len(self.encoding_values) == 2 + assert self.encoding_values[0] < self.encoding_values[1] + value_ranges = tuple([self.encoding_values for _ in self.pre_size]) + elif isinstance(self.encoding_values[0], (tuple, list)): + if len(self.encoding_values) != len(self.pre_size): + raise ConnectorError(f'The network size has {len(self.pre_size)} dimensions, while ' + f'the encoded values provided only has {len(self.encoding_values)}-D. ' + f'Error in {str(self)}.') + for v in self.encoding_values: + assert isinstance(v[0], (int, float)) + assert len(v) == 2 + value_ranges = tuple(self.encoding_values) + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + else: + raise ConnectorError(f'Unsupported encoding values: {self.encoding_values}') + + # values + values = [np.linspace(vs[0], vs[1], n + 1)[:n] for vs, n in zip(value_ranges, self.pre_size)] + # post_values = np.stack([v.flatten() for v in np.meshgrid(*values, indexing='ij')]) + post_values = np.stack([v.flatten() for v in np.meshgrid(*values)]) + value_sizes = np.array([v[1] - v[0] for v in value_ranges]) + if value_sizes.ndim < post_values.ndim: + value_sizes = np.expand_dims(value_sizes, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + + # probability of connections + if isOptimized: + i_value_list = np.zeros(shape=(self.pre_num, len(self.pre_size), 1)) + for i in range(self.pre_num): + list_index = i + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + i_value_list[list_index] = i_value + + if self.periodic_boundary: + prob_mat = gaussian_prob_dist_cal2(i_value_list, post_values, value_sizes, self.sigma) + else: + prob_mat = gaussian_prob_dist_cal1(i_value_list, post_values, self.sigma) + else: + prob_mat = [] + for i in range(self.pre_num): + # values for node i + i_coordinate = tuple() + for s in self.pre_size[:-1]: + i, pos = divmod(i, s) + i_coordinate += (pos,) + i_coordinate += (i,) + i_value = np.array([values[i][c] for i, c in enumerate(i_coordinate)]) + if i_value.ndim < post_values.ndim: + i_value = np.expand_dims(i_value, axis=tuple([i + 1 for i in range(post_values.ndim - 1)])) + # distances + dists = np.abs(i_value - post_values) + if self.periodic_boundary: + dists = np.where(dists > value_sizes / 2, value_sizes - dists, dists) + exp_dists = np.exp(-(np.linalg.norm(dists, axis=0) / self.sigma) ** 2 / 2) + prob_mat.append(exp_dists) + prob_mat = np.stack(prob_mat) + + if self.normalize: + prob_mat /= prob_mat.max() + + # connectivity + conn_mat = np.asarray(prob_mat) >= self.rng.random(prob_mat.shape) + if not self.include_self: + np.fill_diagonal(conn_mat, False) + return conn_mat + + +class SmallWorld(TwoEndConnector): + """Build a Watts–Strogatz small-world graph. + + Parameters + ---------- + num_neighbor : int + Each node is joined with its `k` nearest neighbors in a ring + topology. + prob : float + The probability of rewiring each edge + directed : bool + Whether the graph is a directed graph. + include_self : bool + Whether include the node self. + + Notes + ----- + First create a ring over :math:`num\_node` nodes [1]_. Then each node in the ring is + joined to its :math:`num\_neighbor` nearest neighbors (or :math:`num\_neighbor - 1` neighbors + if :math:`num\_neighbor` is odd). Then shortcuts are created by replacing some edges as + follows: for each edge :math:`(u, v)` in the underlying ":math:`num\_node`-ring with + :math:`num\_neighbor` nearest neighbors" with probability :math:`prob` replace it with a new + edge :math:`(u, w)` with uniformly random choice of existing node :math:`w`. + + References + ---------- + .. [1] Duncan J. Watts and Steven H. Strogatz, + Collective dynamics of small-world networks, + Nature, 393, pp. 440--442, 1998. + """ + + def __init__( + self, + num_neighbor, + prob, + directed=False, + include_self=False, + seed=None, + **kwargs + ): + super(SmallWorld, self).__init__(**kwargs) + self.prob = prob + self.directed = directed + self.num_neighbor = num_neighbor + self.include_self = include_self + + self.seed = format_seed(seed) + self.rng = np.random.RandomState(seed=self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _smallworld_rewire(i, all_j): + if rng.random(1) < prob: + non_connected = np.where(np.logical_not(all_j))[0] + if len(non_connected) <= 1: + return -1 + # Enforce no self-loops or multiple edges + w = rng.choice(non_connected) + while (not include_self) and w == i: + # non_connected.remove(w) + w = rng.choice(non_connected) + return w + else: + return -1 + + self._connect = numba_jit(_smallworld_rewire) + + def __repr__(self): + return (f'{self.__class__.__name__}(prob={self.prob}, ' + f'directed={self.directed}, ' + f'num_neighbor={self.num_neighbor}, ' + f'include_self={self.include_self}, ' + f'seed={self.seed})') + + def build_conn(self): + assert self.pre_size == self.post_size + + # seed + self.seed = self.rng.randint(1, int(1e7)) + numba_seed(self.seed) + + if isinstance(self.pre_size, int) or (isinstance(self.pre_size, (tuple, list)) and len(self.pre_size) == 1): + num_node = self.pre_num + + if self.num_neighbor > num_node: + raise ConnectorError("num_neighbor > num_node, choose smaller num_neighbor or larger num_node") + # If k == n, the graph is complete not Watts-Strogatz + if self.num_neighbor == num_node: + conn = np.ones((num_node, num_node), dtype=MAT_DTYPE) + else: + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + nodes = np.array(list(range(num_node))) # nodes are labeled 0 to n-1 + # connect each node to k/2 neighbors + for j in range(1, self.num_neighbor // 2 + 1): + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + conn[nodes, targets] = True + conn[targets, nodes] = True + + # rewire edges from each node + # loop over all nodes in order (label) and neighbors in order (distance) + # no self loops or multiple edges allowed + for j in range(1, self.num_neighbor // 2 + 1): # outer loop is neighbors + targets = np.concatenate([nodes[j:], nodes[0:j]]) # first j nodes are now last in list + if self.directed: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(prob=self.prob, i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[u, w] = True + w = self._connect(prob=self.prob, i=u, all_j=conn[:, u]) + if w != -1: + conn[v, u] = False + conn[w, u] = True + else: + # inner loop in node order + for u, v in zip(nodes, targets): + w = self._connect(i=u, all_j=conn[u]) + if w != -1: + conn[u, v] = False + conn[v, u] = False + conn[u, w] = True + conn[w, u] = True + # conn = np.asarray(conn, dtype=MAT_DTYPE) + else: + raise ConnectorError('Currently only support 1D ring connection.') + + return 'mat', conn + + +# def _random_subset(seq, m, rng): +# """Return m unique elements from seq. +# +# This differs from random.sample which can return repeated +# elements if seq holds repeated elements. +# +# Note: rng is a random.Random or numpy.random.RandomState instance. +# """ +# targets = set() +# while len(targets) < m: +# x = rng.choice(seq) +# targets.add(x) +# return targets + + +class ScaleFreeBA(TwoEndConnector): + """Build a random graph according to the Barabási–Albert preferential + attachment model. + + A graph of :math:`num\_node` nodes is grown by attaching new nodes each with + :math:`m` edges that are preferentially attached to existing nodes + with high degree. + + Parameters + ---------- + m : int + Number of edges to attach from a new node to existing nodes + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises + ------ + ConnectorError + If `m` does not satisfy ``1 <= m < n``. + + References + ---------- + .. [1] A. L. Barabási and R. Albert "Emergence of scaling in + random networks", Science 286, pp 509-512, 1999. + """ + + def __init__(self, m, directed=False, seed=None, **kwargs): + super(ScaleFreeBA, self).__init__(**kwargs) + self.m = m + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, ' + f'directed={self.directed}, ' + f'seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m < 1 or self.m >= num_node: + raise ConnectorError(f"Barabási–Albert network must have m >= 1 and " + f"m < n, while m = {self.m} and n = {num_node}") + + # Add m initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + # Target nodes for new edges + targets = list(range(self.m)) + # List of existing nodes, with nodes repeated once for each adjacent edge + + if not isOptimized: + repeated_nodes = [] + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * self.m) + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), self.m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * self.m, dtype=int) + size_repeated_nodes = 0 + # Start adding the other n-m nodes. The first node is m. + source = self.m + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * self.m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = targets + size_repeated_nodes += self.m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + self.m] = source + size_repeated_nodes += self.m + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], self.m)) + source += 1 + + return conn + + +class ScaleFreeBADual(TwoEndConnector): + r"""Build a random graph according to the dual Barabási–Albert preferential + attachment model. + + A graph of :math::`num\_node` nodes is grown by attaching new nodes each with either $m_1$ + edges (with probability :math:`p`) or :math:`m_2` edges (with probability :math:`1-p`) that + are preferentially attached to existing nodes with high degree. + + Parameters + ---------- + m1 : int + Number of edges to attach from a new node to existing nodes with probability :math:`p` + m2 : int + Number of edges to attach from a new node to existing nodes with probability :math:`1-p` + p : float + The probability of attaching :math:`m\_1` edges (as opposed to :math:`m\_2` edges) + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Raises + ------ + ConnectorError + If `m1` and `m2` do not satisfy ``1 <= m1,m2 < n`` or `p` does not satisfy ``0 <= p <= 1``. + + References + ---------- + .. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538. + """ + + def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs): + super(ScaleFreeBADual, self).__init__(**kwargs) + self.m1 = m1 + self.m2 = m2 + self.p = p + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m1={self.m1}, m2={self.m2}, ' + f'p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + + num_node = self.pre_num + if self.m1 < 1 or self.m1 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m1 >= 1 and m1 < num_node, " + f"while m1 = {self.m1} and num_node = {num_node}.") + if self.m2 < 1 or self.m2 >= num_node: + raise ConnectorError(f"Dual Barabási–Albert network must have m2 >= 1 and m2 < num_node, " + f"while m2 = {self.m2} and num_node = {num_node}.") + if self.p < 0 or self.p > 1: + raise ConnectorError(f"Dual Barabási–Albert network must have 0 <= p <= 1, while p = {self.p}") + + # Add max(m1,m2) initial nodes (m0 in barabasi-speak) + conn = np.zeros((num_node, num_node), dtype=MAT_DTYPE) + + if not isOptimized: + # List of existing nodes, with nodes repeated once for each adjacent edge + repeated_nodes = [] + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes.extend(targets) + # And the new node "source" has m edges to add to the list. + repeated_nodes.extend([source] * m) + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(np.asarray(repeated_nodes), m)) + source += 1 + return conn + + # List of existing nodes, with nodes repeated once for each adjacent edge + # Preallocate repeated_nodes as a numpy array + repeated_nodes = np.empty(2 * num_node * max(self.m1, self.m2), dtype=int) + size_repeated_nodes = 0 + # Start adding the remaining nodes. + source = max(self.m1, self.m2) + # Pick which m to use first time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Target nodes for new edges + targets = list(range(m)) + while source < num_node: + # Add edges to m nodes from the source. + origins = [source] * m + conn[origins, targets] = True + if not self.directed: + conn[targets, origins] = True + # Add one node to the list for each new edge just created. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = targets + size_repeated_nodes += m + # And the new node "source" has m edges to add to the list. + repeated_nodes[size_repeated_nodes:size_repeated_nodes + m] = source + size_repeated_nodes += m + # Pick which m to use next time (m1 or m2) + m = self.m1 if self.rng.random() < self.p else self.m2 + # Now choose m unique nodes from the existing nodes + # Pick uniformly from repeated_nodes (preferential attachment) + targets = list(self._connect(repeated_nodes[:size_repeated_nodes], m)) + source += 1 + + return conn + + +class PowerLaw(TwoEndConnector): + """Holme and Kim algorithm for growing graphs with powerlaw + degree distribution and approximate average clustering. + + Parameters + ---------- + m : int + the number of random edges to add for each new node + p : float, + Probability of adding a triangle after adding a random edge + seed : integer, random_state, or None (default) + Indicator of random number generation state. + + Notes + ----- + The average clustering has a hard time getting above a certain + cutoff that depends on :math:`m`. This cutoff is often quite low. The + transitivity (fraction of triangles to possible triangles) seems to + decrease with network size. + + It is essentially the Barabási–Albert (BA) growth model with an + extra step that each random edge is followed by a chance of + making an edge to one of its neighbors too (and thus a triangle). + + This algorithm improves on BA in the sense that it enables a + higher average clustering to be attained if desired. + + It seems possible to have a disconnected graph with this algorithm + since the initial :math:`m` nodes may not be all linked to a new node + on the first iteration like the BA model. + + Raises + ------ + ConnectorError + If :math:`m` does not satisfy :math:`1 <= m <= n` or :math:`p` does not + satisfy :math:`0 <= p <= 1`. + + References + ---------- + .. [1] P. Holme and B. J. Kim, + "Growing scale-free networks with tunable clustering", + Phys. Rev. E, 65, 026107, 2002. + """ + + def __init__(self, m: int, p: float, directed=False, seed=None, **kwargs): + super(PowerLaw, self).__init__(**kwargs) + self.m = m + self.p = p + if self.p > 1 or self.p < 0: + raise ConnectorError(f"p must be in [0,1], while p={self.p}") + self.directed = directed + self.seed = format_seed(seed) + self.rng = np.random.RandomState(self.seed) + rng = np.random if SUPPORT_NUMBA else self.rng + + def _random_subset(seq, m): + targets = set() + while len(targets) < m: + x = rng.choice(seq) + targets.add(x) + return targets + + self._connect = numba_jit(_random_subset) + + def __repr__(self): + return (f'{self.__class__.__name__}(m={self.m}, p={self.p}, directed={self.directed}, seed={self.seed})') + + def build_mat(self, isOptimized=True): + assert self.pre_num == self.post_num + # seed + self.rng = np.random.RandomState(self.seed) + numba_seed(self.seed) + num_node = self.pre_num + if self.m < 1 or num_node < self.m: + raise ConnectorError(f"Must have m>1 and m 1 else p.flatten() for p in pre_ids]) + size = np.prod(pre_size) + + for i in range(size): + pre_pos = np.asarray([p[i] for p in pre_ids]) + pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim) + connected_pres.extend(pres) + connected_posts.extend(posts) + return np.asarray(connected_pres), np.asarray(connected_posts) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 3bba20a7..b8bd6e99 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,87 +1,135 @@ -import os -import sys -from jax.lib import xla_client - -__all__ = [ - 'import_taichi', - 'import_brainpylib_cpu_ops', - 'import_brainpylib_gpu_ops', -] - -_minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 0) - -taichi = None -brainpylib_cpu_ops = None -brainpylib_gpu_ops = None - -taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' - f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' - '> pip install taichi==1.7.0') -os.environ["TI_LOG_LEVEL"] = "error" - - -def import_taichi(): - global taichi - if taichi is None: - with open(os.devnull, 'w') as devnull: - old_stdout = sys.stdout - sys.stdout = devnull - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - raise ModuleNotFoundError(taichi_install_info) - finally: - sys.stdout = old_stdout - - if taichi.__version__ != _minimal_taichi_version: - raise RuntimeError(taichi_install_info) - return taichi - - -def is_brainpylib_gpu_installed(): - return False if brainpylib_gpu_ops is None else True - - -def import_brainpylib_cpu_ops(): - global brainpylib_cpu_ops - if brainpylib_cpu_ops is None: - try: - from brainpylib import cpu_ops as brainpylib_cpu_ops - - for _name, _value in brainpylib_cpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="cpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_cpu_ops - - -def import_brainpylib_gpu_ops(): - global brainpylib_gpu_ops - if brainpylib_gpu_ops is None: - try: - from brainpylib import gpu_ops as brainpylib_gpu_ops - - for _name, _value in brainpylib_gpu_ops.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") - - import brainpylib - if brainpylib.__version__ < _minimal_brainpylib_version: - raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') - if hasattr(brainpylib, 'check_brainpy_version'): - brainpylib.check_brainpy_version() - - except ImportError: - raise ImportError('Please install GPU version of brainpylib. \n' - 'See https://brainpy.readthedocs.io for installation instructions.') - - return brainpylib_gpu_ops +import os +import sys + +from jax.lib import xla_client + +__all__ = [ + 'import_taichi', + 'raise_taichi_not_found', + 'import_numba', + 'raise_numba_not_found', + 'import_brainpylib_cpu_ops', + 'import_brainpylib_gpu_ops', +] + +_minimal_brainpylib_version = '0.2.6' +_minimal_taichi_version = (1, 7, 0) + +numba = None +taichi = None +brainpylib_cpu_ops = None +brainpylib_gpu_ops = None + +taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' + f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' + '> pip install taichi==1.7.0') +numba_install_info = ('We need numba. Please install numba by pip . \n' + '> pip install numba') +os.environ["TI_LOG_LEVEL"] = "error" + + +def import_taichi(error_if_not_found=True): + """Internal API to import taichi. + + If taichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global taichi + if taichi is None: + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except ModuleNotFoundError: + if error_if_not_found: + raise raise_taichi_not_found() + finally: + sys.stdout = old_stdout + + if taichi is None: + return None + if taichi.__version__ != _minimal_taichi_version: + raise RuntimeError(taichi_install_info) + return taichi + + +def raise_taichi_not_found(*args, **kwargs): + raise ModuleNotFoundError(taichi_install_info) + + +def import_numba(error_if_not_found=True): + """ + Internal API to import numba. + + If numba is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global numba + if numba is None: + try: + import numba as numba + except ModuleNotFoundError: + if error_if_not_found: + raise_numba_not_found() + else: + return None + return numba + + +def raise_numba_not_found(): + raise ModuleNotFoundError(numba_install_info) + + +def is_brainpylib_gpu_installed(): + return False if brainpylib_gpu_ops is None else True + + +def import_brainpylib_cpu_ops(): + """ + Internal API to import brainpylib cpu_ops. + """ + global brainpylib_cpu_ops + if brainpylib_cpu_ops is None: + try: + from brainpylib import cpu_ops as brainpylib_cpu_ops + + for _name, _value in brainpylib_cpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="cpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install brainpylib. \n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_cpu_ops + + +def import_brainpylib_gpu_ops(): + """ + Internal API to import brainpylib gpu_ops. + """ + global brainpylib_gpu_ops + if brainpylib_gpu_ops is None: + try: + from brainpylib import gpu_ops as brainpylib_gpu_ops + + for _name, _value in brainpylib_gpu_ops.registrations().items(): + xla_client.register_custom_call_target(_name, _value, platform="gpu") + + import brainpylib + if brainpylib.__version__ < _minimal_brainpylib_version: + raise SystemError(f'This version of brainpy needs brainpylib >= {_minimal_brainpylib_version}.') + if hasattr(brainpylib, 'check_brainpy_version'): + brainpylib.check_brainpy_version() + + except ImportError: + raise ImportError('Please install GPU version of brainpylib. \n' + 'See https://brainpy.readthedocs.io for installation instructions.') + + return brainpylib_gpu_ops diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index e4b6e25d..deead1f3 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = bm.unsqueeze(x, 0) + x = x.unsqueeze(0) w = self.w.value if self.mask is not None: try: @@ -190,9 +190,6 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. - The input should a 2d array with the shape of ``[H, C]``, or - a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. - Parameters ---------- in_channels: int @@ -285,9 +282,6 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, C]``, or - a 4d array with the shape of ``[B, H, W, C]``. - Parameters ---------- in_channels: int @@ -381,9 +375,6 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. - The input should a 3d array with the shape of ``[H, W, D, C]``, or - a 4d array with the shape of ``[B, H, W, D, C]``. - Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 539214d3..c524fb0b 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -6,22 +6,21 @@ import jax import jax.numpy as jnp -import numba import numpy as np from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share +from brainpy._src.dependency_check import import_taichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP -from brainpy._src.dependency_check import import_taichi from brainpy.check import is_initializer from brainpy.connect import csr2csc -from brainpy.errors import MathError +from brainpy.errors import MathError, PackageMissingError from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'Dense', 'Linear', @@ -239,140 +238,106 @@ def update(self, x): return x -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) - -@ti.kernel -def _cpu_dense_on_pre(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 +if ti is not None: + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_post( + old_w: ti.types.ndarray(ndim=2), + post_spike: ti.types.ndarray(ndim=1), + pre_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if post_spike[j]: + new_value = out_w[i, j] + pre_trace[i] if new_value < w_min0: out_w[i, j] = w_min0 elif new_value > w_max0: out_w[i, j] = w_max0 else: - out_w[i, j] = new_value - - -@ti.kernel -def _gpu_dense_on_pre(weight: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[1]): - new_value = out_w[i, j] + trace0 + out_w[i, j] = new_value + else: + out_w[i, j] = old_w[i, j] + + + dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_dense_on_pre(weight, spike, trace, w_min, w_max, out_w): + # out_w[:] = weight + # for i in numba.prange(spike.shape[0]): + # if spike[i]: + # out_w[i] = np.clip(out_w[i] + trace, w_min, w_max) + + @ti.kernel + def _dense_on_pre( + old_w: ti.types.ndarray(ndim=2), + pre_spike: ti.types.ndarray(ndim=1), + post_trace: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + out_w: ti.types.ndarray(ndim=2) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre, num_post = out_w.shape + + for i, j in ti.ndrange(num_pre, num_post): + if pre_spike[i]: + new_value = out_w[i, j] + post_trace[j] if new_value < w_min0: out_w[i, j] = w_min0 elif new_value > w_max0: out_w[i, j] = w_max0 else: out_w[i, j] = new_value - + else: + out_w[i, j] = old_w[i, j] -dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_pre, - gpu_kernel=_gpu_dense_on_pre) + + dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) + +else: + dense_on_pre_prim = None + dense_on_post_prim = None def dense_on_pre(weight, spike, trace, w_min, w_max): + if dense_on_pre_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): -# out_w[:] = weight -# for i in numba.prange(spike.shape[0]): -# if spike[i]: -# out_w[:, i] = np.clip(out_w[:, i] + trace, w_min, w_max) - -@ti.kernel -def _cpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - -@ti.kernel -def _gpu_dense_on_post(weight: ti.types.ndarray(ndim=2), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=2)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i, j in ti.ndrange(out_w.shape[0], out_w.shape[1]): - out_w[i, j] = weight[i, j] - for i in range(spike.shape[0]): - if spike[i]: - for j in range(out_w.shape[0]): - new_value = out_w[j, i] + trace0 - if new_value < w_min0: - out_w[j, i] = w_min0 - elif new_value > w_max0: - out_w[j, i] = w_max0 - else: - out_w[j, i] = new_value - -dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_cpu_dense_on_post, - gpu_kernel=_gpu_dense_on_post) - - def dense_on_post(weight, spike, trace, w_min, w_max): + if dense_on_post_prim is None: + raise PackageMissingError.by_purpose('taichi', 'custom operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return dense_on_post_prim(weight, spike, trace, w_min, w_max, @@ -630,7 +595,7 @@ def stdp_update( raise ValueError(f'The shape of weight should be the same as the shape of sparse weight {self.weight.shape}.') if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - if on_pre is not None: # update on presynaptic spike + if on_pre is not None: # update on presynaptic spike spike = on_pre['spike'] trace = on_pre['trace'] self.weight.value = csr_on_pre_update(self.weight.value, self.indices, self.indptr, spike, trace, w_min, w_max) @@ -682,8 +647,7 @@ def __init__( def update(self, x): if x.ndim == 1: return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) elif x.ndim > 1: shapes = x.shape[:-1] x = bm.flatten(x, end_dim=-2) @@ -694,8 +658,8 @@ def update(self, x): def _batch_csrmv(self, x): return bm.sparse.csrmv(self.weight, self.indices, self.indptr, x, - shape=(self.conn.pre_num, self.conn.post_num), - method=self.method, transpose=self.transpose) + shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) + class EventCSRLinear(_CSRLayer): r"""Synaptic matrix multiplication with event CSR sparse computation. @@ -746,99 +710,170 @@ def _batch_csrmv(self, x): shape=(self.conn.pre_num, self.conn.post_num), transpose=self.transpose) -# @numba.njit(nogil=True, fastmath=True, parallel=False) -# def _cpu_csr_on_pre_update(w, indices, indptr, spike, trace, w_min, w_max, out_w): -# out_w[:] = w -# w_min = w_min[()] -# w_max = w_max[()] -# for i in numba.prange(spike.shape[0]): # pre id -# if spike[i]: -# for k in range(indptr[i], indptr[i + 1]): # synapse id -# j = indices[k] # post id -# # out_w[k] = np.clip(out_w[k] + trace[j], w_min, w_max) -# out_w[k] = np.minimum(np.maximum(out_w[k] + trace[j], w_min), w_max) - - -@ti.kernel -def _cpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) -@ti.kernel -def _gpu_csr_on_pre_update(w: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - spike: ti.types.ndarray(ndim=1), - trace: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - out_w: ti.types.ndarray(ndim=1)): - trace0 = trace[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - for i in range(out_w.shape[0]): - out_w[i] = w[i] - for i in range(spike.shape[0]): - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = indices[k] - out_w[k] = min(max(out_w[k] + trace[j], w_min0), w_max0) - - -csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_cpu_csr_on_pre_update, - gpu_kernel=_gpu_csr_on_pre_update) + +if ti is not None: + @ti.kernel + def _csr_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_pre + 1) + spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_pre = spike.shape[0] + for i_pre in range(num_pre): + if spike[i_pre]: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = min(max(old_w[i_syn] + trace[indices[i_syn]], w_min0), w_max0) + else: + for i_syn in range(indptr[i_pre], indptr[i_pre + 1]): + out_w[i_syn] = old_w[i_syn] + + + csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + + + @ti.kernel + def _coo_on_pre_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + post_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if pre_spike[pre_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + post_trace[post_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + + + @ti.kernel + def _coo_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + pre_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1) # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_syn = old_w.shape[0] + for i_syn in range(num_syn): + if post_spike[post_ids[i_syn]]: # pre spike + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[pre_ids[i_syn]], w_min0), w_max0) + else: + out_w[i_syn] = old_w[i_syn] + + + coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + + + # @numba.njit(nogil=True, fastmath=True, parallel=False) + # def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): + # out_w[:] = w + # w_min = w_min[()] + # w_max = w_max[()] + # for i in numba.prange(spike.shape[0]): # post id + # if spike[i]: + # for k in range(indptr[i], indptr[i + 1]): + # j = post_ids[k] # pre id + # l = w_ids[k] # syn id + # out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) + + @ti.kernel + def _csc_on_post_update( + old_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indices: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + indptr: ti.types.ndarray(ndim=1), # vector with shape of (num_post + 1) + w_ids: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + post_spike: ti.types.ndarray(ndim=1), # vector with shape of (num_post,) + pre_trace: ti.types.ndarray(ndim=1), # vector with shape of (num_pre,) + w_min: ti.types.ndarray(ndim=1), # scalar + w_max: ti.types.ndarray(ndim=1), # scalar + out_w: ti.types.ndarray(ndim=1), # vector with shape of (num_syn) + ): + w_min0 = w_min[0] + w_max0 = w_max[0] + num_post = post_spike.shape[0] + for i_post in range(num_post): + if post_spike[i_post]: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = min(max(old_w[i_syn] + pre_trace[indices[k]], w_min0), w_max0) + else: + for k in range(indptr[i_post], indptr[i_post + 1]): + i_syn = w_ids[k] # syn id + out_w[i_syn] = old_w[i_syn] + + + csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) + + +else: + csr_on_pre_update_prim = None + coo_on_pre_update_prim = None + csc_on_post_update_prim = None def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): + if csr_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - trace = jnp.atleast_1d(trace) w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] -@numba.njit(nogil=True, fastmath=True, parallel=False) -def _cpu_csc_on_pre_update(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, out_w): - out_w[:] = w - w_min = w_min[()] - w_max = w_max[()] - for i in numba.prange(spike.shape[0]): # post id - if spike[i]: - for k in range(indptr[i], indptr[i + 1]): - j = post_ids[k] # pre id - l = w_ids[k] # syn id - out_w[l] = np.minimum(np.maximum(out_w[l] + trace[j], w_min), w_max) +def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None): + if coo_on_pre_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') -csc_on_pre_update_prim = bm.XLACustomOp(_cpu_csc_on_pre_update) - - -def csc_on_post_update(w, post_ids, indptr, w_ids, spike, trace, w_min=None, w_max=None): if w_min is None: w_min = -np.inf if w_max is None: w_max = np.inf - return csc_on_pre_update_prim(w, post_ids, indptr, w_ids, spike, trace, w_min, w_max, + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] +def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min=None, w_max=None): + if csc_on_post_update_prim is None: + raise PackageMissingError.by_purpose('taichi', 'customized operators') + + if w_min is None: + w_min = -np.inf + if w_max is None: + w_max = np.inf + w_min = jnp.atleast_1d(w_min) + w_max = jnp.atleast_1d(w_max) + return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, + outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] + class CSCLinear(Layer): r"""Synaptic matrix multiplication with CSC sparse computation. diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index 17054667..ba2a49ef 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,6 +1,5 @@ -from absl.testing import absltest from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 05f52362..3c9fdfa8 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,15 +1,17 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp +from unittest import TestCase from absl.testing import absltest +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import parameterized - import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): + bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -22,7 +24,6 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -30,6 +31,7 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): + bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -37,7 +39,6 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -53,7 +54,6 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,7 +67,6 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) - self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 9ad15938..269fec44 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,10 +1,12 @@ # -*- coding: utf-8 -*- +from unittest import TestCase + +import jax.numpy as jnp +import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized - import brainpy as bp -import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index df5293ab..422f161f 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,218 +1,223 @@ -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class TestLinear(parameterized.TestCase): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - bm.random.seed() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - num_out=[20, 10, 5] - ) - def test_Dense1(self, size, num_out): - bm.random.seed() - f = bp.dnn.Linear(10, num_out) - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size[:-1] + (num_out,)) - bm.clear_buffer_memory() - - @parameterized.product( - size=[(10,), - (20, 10), - (5, 8, 10)], - ) - def test_Identity(self, size): - bm.random.seed() - f = bp.dnn.Identity() - x = bm.random.random(size) - y = f(x) - self.assertTrue(y.shape == size) - bm.clear_buffer_memory() - - def test_AllToAll1(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((8, 10)) - y = f(x) - expected = bm.sum(x, axis=1, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) - x = bm.random.random((10,)) - y = f(x) - expected = bm.sum(x, keepdims=True) * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - def test_OneToOne(self): - bm.random.seed() - with bm.environment(mode=bm.BatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((8, 10)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - - with bm.environment(mode=bm.NonBatchingMode()): - f = bp.dnn.OneToOne(10, weight=.1) - x = bm.random.random((10,)) - y = f(x) - expected = x * 0.1 - self.assertTrue(bm.allclose(y, expected)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - # bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_MaskedLinear(self, conn): - bm.random.seed() - bm.random.DEFAULT.seed(123) - f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - bm.clear_buffer_memory() - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_CSRLinear(self, conn): - bm.random.seed() - f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - conn=[ - bp.conn.FixedProb(0.1, pre=100, post=100), - bp.conn.GridFour(pre=100, post=100), - bp.conn.GaussianProb(0.1, pre=100, post=100), - ] - ) - def test_EventCSRLinear(self,conn): - bm.random.seed() - f=bp.layers.EventCSRLinear(conn,weight=bp.init.Normal()) - x = bm.random.random((16, 100)) - y = f(x) - self.assertTrue(y.shape == (16, 100)) - x = bm.random.random((100,)) - y = f(x) - self.assertTrue(y.shape == (100,)) - bm.clear_buffer_memory() - - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - x = bm.random.random(shape + (100,)) - y = f(x) - self.assertTrue(y.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPHomoLinear(self, prob, weight, shape): - bm.random.seed() - f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): - bm.random.seed() - f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - - @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], - shape=[(), (10,), (10, 20), (10, 20, 25)] - ) - def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): - bm.random.seed() - f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) - y = f(bm.random.random(shape + (100,)) < 0.1) - self.assertTrue(y.shape == shape + (200,)) - - y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) - self.assertTrue(y2.shape == shape + (200,)) - bm.clear_buffer_memory() - -if __name__ == '__main__': - absltest.main() +import pytest +from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestLinear(parameterized.TestCase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + bm.random.seed() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + num_out=[20, 10, 5] + ) + def test_Dense1(self, size, num_out): + bm.random.seed() + f = bp.dnn.Linear(10, num_out) + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size[:-1] + (num_out,)) + bm.clear_buffer_memory() + + @parameterized.product( + size=[(10,), + (20, 10), + (5, 8, 10)], + ) + def test_Identity(self, size): + bm.random.seed() + f = bp.dnn.Identity() + x = bm.random.random(size) + y = f(x) + self.assertTrue(y.shape == size) + bm.clear_buffer_memory() + + def test_AllToAll1(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((8, 10)) + y = f(x) + expected = bm.sum(x, axis=1, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.AllToAll(10, 20, weight=.1, include_self=True) + x = bm.random.random((10,)) + y = f(x) + expected = bm.sum(x, keepdims=True) * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + def test_OneToOne(self): + bm.random.seed() + with bm.environment(mode=bm.BatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((8, 10)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + + with bm.environment(mode=bm.NonBatchingMode()): + f = bp.dnn.OneToOne(10, weight=.1) + x = bm.random.random((10,)) + y = f(x) + expected = x * 0.1 + self.assertTrue(bm.allclose(y, expected)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + # bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_MaskedLinear(self, conn): + bm.random.seed() + bm.random.DEFAULT.seed(123) + f = bp.dnn.MaskedLinear(conn, weight=bp.init.XavierNormal(seed=123)) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_CSRLinear(self, conn): + bm.random.seed() + f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + conn=[ + bp.conn.FixedProb(0.1, pre=100, post=100), + bp.conn.GridFour(pre=100, post=100), + bp.conn.GaussianProb(0.1, pre=100, post=100), + ] + ) + def test_EventCSRLinear(self, conn): + bm.random.seed() + f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) + x = bm.random.random((16, 100)) + y = f(x) + self.assertTrue(y.shape == (16, 100)) + x = bm.random.random((100,)) + y = f(x) + self.assertTrue(y.shape == (100,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.JitFPHomoLinear(100, 200, prob, weight, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.JitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.JitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + x = bm.random.random(shape + (100,)) + y = f(x) + self.assertTrue(y.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + weight=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPHomoLinear(self, prob, weight, shape): + bm.random.seed() + f = bp.dnn.EventJitFPHomoLinear(100, 200, prob, weight, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.05, 0.5], + w_low=[-0.01, -0.01], + w_high=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): + bm.random.seed() + f = bp.dnn.EventJitFPUniformLinear(100, 200, prob, w_low, w_high, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + @parameterized.product( + prob=[0.01, 0.1, 0.5], + w_mu=[-0.01, -0.01], + w_sigma=[0.01, 0.01], + shape=[(), (10,), (10, 20), (10, 20, 25)] + ) + def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): + bm.random.seed() + f = bp.dnn.EventJitFPNormalLinear(100, 200, prob, w_mu, w_sigma, seed=123) + y = f(bm.random.random(shape + (100,)) < 0.1) + self.assertTrue(y.shape == shape + (200,)) + + y2 = f(bm.as_jax(bm.random.random(shape + (100,)) < 0.1, dtype=float)) + self.assertTrue(y2.shape == shape + (200,)) + bm.clear_buffer_memory() + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 3cf923d7..f0c67da1 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -1,801 +1,807 @@ -from absl.testing import absltest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class Test_Conv(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(50, 3) - layer = bp.dnn.Conv1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv2_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.Conv2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_Conv3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode) - output = layer(input) - bm.clear_buffer_memory() - - def test_Conv3_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.Conv3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose1d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 3) - layer = bp.dnn.ConvTranspose1d(in_channels=3, - out_channels=4, - kernel_size=5, - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose2d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 3) - layer = bp.dnn.ConvTranspose2d(in_channels=3, - out_channels=4, - kernel_size=(5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), ] - ) - def test_ConvTranspose3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 50, 50, 50, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=mode - ) - output = layer(input) - bm.clear_buffer_memory() - - def test_ConvTranspose3d_NonBatching(self): - bm.random.seed() - input = bm.random.randn(10, 10, 10, 3) - layer = bp.dnn.ConvTranspose3d(in_channels=3, - out_channels=4, - kernel_size=(5, 5, 5), - mode=bm.NonBatchingMode()) - output = layer(input) - bm.clear_buffer_memory() - - -class TestPool(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MinPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AvgPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.MaxPool1d(kernel_size=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveAvgPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool1d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 4) - layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool2d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AdaptiveMaxPool3d(self, mode): - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), - channel_axis=-1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - -class Test_Dropout(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Dropout(self, mode): - bp.share.save(fit=False) - bm.random.seed() - input = bm.random.randn(10, 5, 5, 5, 4) - layer = bp.dnn.Dropout(prob=0.2, - mode=mode) - output = layer(input) - - -class Test_function(parameterized.TestCase): - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_Flatten(self, mode): - bm.random.seed() - layer = bp.dnn.Flatten(mode=mode) - input = bm.random.randn(10, 5, 5, 5, 4) - output = layer(input) - - -class Test_linear(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_linear(self, mode): - bm.random.seed() - input = bm.random.randn(10, 9, 8, 7) - layer = bp.dnn.Linear(num_in=7, - num_out=6, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_AllToAll(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.AllToAll(num_pre=10, - num_post=20, - weight=0.1, - mode=mode) - if mode in [bm.NonBatchingMode()]: - for i in input: - output = layer(i) - else: - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_OneToOne(self, mode): - bm.random.seed() - input = bm.random.randn(10, 10) - layer = bp.dnn.OneToOne(num=10, - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_MaskedLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_CSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventCSRLinear(self, mode): - bm.random.seed() - input = bm.random.randn(100, 100) - layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), - weight=0.1, - mode=mode) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_JitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.JitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPHomoLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPHomoLinear(num_in=100, - num_out=200, - prob=0.1, - weight=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPNormalLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPNormalLinear(num_in=100, - num_out=200, - prob=0.1, - w_mu=-0.01, - w_sigma=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()] - ) - def test_EventJitFPUniformLinear(self, mode): - bm.random.seed() - layer = bp.dnn.EventJitFPUniformLinear(num_in=100, - num_out=200, - prob=0.1, - w_low=-0.01, - w_high=0.01, - seed=100, - mode=mode) - input = bm.random.randn(10, 100) - output = layer(input) - - -class Test_Normalization(parameterized.TestCase): - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm1d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm1d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm2d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm2d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10)], - fit=[True, False] - ) - def test_BatchNorm3d(self, fit, mode): - bm.random.seed() - bp.share.save(fit=fit) - layer = bp.dnn.BatchNorm3d(num_features=100, - mode=mode, - affine=False) - input = bm.random.randn(10, 5, 6, 7, 100) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_LayerNorm(self, mode): - bm.random.seed() - layer = bp.dnn.LayerNorm(normalized_shape=3, - mode=mode, - elementwise_affine=False - ) - input = bm.random.randn(10, 5, 3) - outout = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_GroupNorm(self, mode): - bm.random.seed() - layer = bp.dnn.GroupNorm(num_groups=2, - num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - @parameterized.product( - mode=[bm.TrainingMode(), - bm.TrainingMode(10), - bm.BatchingMode(), - bm.BatchingMode(10), - bm.NonBatchingMode()], - ) - def test_InstanceNorm(self, mode): - bm.random.seed() - layer = bp.dnn.InstanceNorm(num_channels=6, - affine=False, - mode=mode - ) - input = bm.random.randn(20, 10, 10, 6) - output = layer(input) - - -if __name__ == '__main__': - absltest.main() +import pytest +from absl.testing import absltest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Conv(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(50, 3) + layer = bp.dnn.Conv1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv2_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.Conv2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_Conv3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode) + output = layer(input) + bm.clear_buffer_memory() + + def test_Conv3_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.Conv3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose1d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 3) + layer = bp.dnn.ConvTranspose1d(in_channels=3, + out_channels=4, + kernel_size=5, + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose2d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 3) + layer = bp.dnn.ConvTranspose2d(in_channels=3, + out_channels=4, + kernel_size=(5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), ] + ) + def test_ConvTranspose3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 50, 50, 50, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=mode + ) + output = layer(input) + bm.clear_buffer_memory() + + def test_ConvTranspose3d_NonBatching(self): + bm.random.seed() + input = bm.random.randn(10, 10, 10, 3) + layer = bp.dnn.ConvTranspose3d(in_channels=3, + out_channels=4, + kernel_size=(5, 5, 5), + mode=bm.NonBatchingMode()) + output = layer(input) + bm.clear_buffer_memory() + + +class TestPool(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MinPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AvgPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AvgPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AvgPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.MaxPool1d(kernel_size=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.MaxPool2d(kernel_size=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.MaxPool3d(kernel_size=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveAvgPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveAvgPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveAvgPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool1d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 4) + layer = bp.dnn.AdaptiveMaxPool1d(target_shape=3, + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool2d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool2d(target_shape=(3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AdaptiveMaxPool3d(self, mode): + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.AdaptiveMaxPool3d(target_shape=(3, 3, 3), + channel_axis=-1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + +class Test_Dropout(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Dropout(self, mode): + bp.share.save(fit=False) + bm.random.seed() + input = bm.random.randn(10, 5, 5, 5, 4) + layer = bp.dnn.Dropout(prob=0.2, + mode=mode) + output = layer(input) + + +class Test_function(parameterized.TestCase): + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_Flatten(self, mode): + bm.random.seed() + layer = bp.dnn.Flatten(mode=mode) + input = bm.random.randn(10, 5, 5, 5, 4) + output = layer(input) + + +class Test_linear(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_linear(self, mode): + bm.random.seed() + input = bm.random.randn(10, 9, 8, 7) + layer = bp.dnn.Linear(num_in=7, + num_out=6, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_AllToAll(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.AllToAll(num_pre=10, + num_post=20, + weight=0.1, + mode=mode) + if mode in [bm.NonBatchingMode()]: + for i in input: + output = layer(i) + else: + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_OneToOne(self, mode): + bm.random.seed() + input = bm.random.randn(10, 10) + layer = bp.dnn.OneToOne(num=10, + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_MaskedLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.MaskedLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_CSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.CSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventCSRLinear(self, mode): + bm.random.seed() + input = bm.random.randn(100, 100) + layer = bp.dnn.EventCSRLinear(conn=bp.conn.FixedProb(0.1, pre=100, post=100), + weight=0.1, + mode=mode) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_JitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.JitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPHomoLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPHomoLinear(num_in=100, + num_out=200, + prob=0.1, + weight=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPNormalLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPNormalLinear(num_in=100, + num_out=200, + prob=0.1, + w_mu=-0.01, + w_sigma=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()] + ) + def test_EventJitFPUniformLinear(self, mode): + bm.random.seed() + layer = bp.dnn.EventJitFPUniformLinear(num_in=100, + num_out=200, + prob=0.1, + w_low=-0.01, + w_high=0.01, + seed=100, + mode=mode) + input = bm.random.randn(10, 100) + output = layer(input) + + +class Test_Normalization(parameterized.TestCase): + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm1d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm1d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm2d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm2d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10)], + fit=[True, False] + ) + def test_BatchNorm3d(self, fit, mode): + bm.random.seed() + bp.share.save(fit=fit) + layer = bp.dnn.BatchNorm3d(num_features=100, + mode=mode, + affine=False) + input = bm.random.randn(10, 5, 6, 7, 100) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_LayerNorm(self, mode): + bm.random.seed() + layer = bp.dnn.LayerNorm(normalized_shape=3, + mode=mode, + elementwise_affine=False + ) + input = bm.random.randn(10, 5, 3) + outout = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_GroupNorm(self, mode): + bm.random.seed() + layer = bp.dnn.GroupNorm(num_groups=2, + num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + @parameterized.product( + mode=[bm.TrainingMode(), + bm.TrainingMode(10), + bm.BatchingMode(), + bm.BatchingMode(10), + bm.NonBatchingMode()], + ) + def test_InstanceNorm(self, mode): + bm.random.seed() + layer = bp.dnn.InstanceNorm(num_channels=6, + affine=False, + mode=mode + ) + input = bm.random.randn(20, 10, 10, 6) + output = layer(input) + + +if __name__ == '__main__': + absltest.main() diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index de2c9765..fdc5b34e 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,8 +1,7 @@ -from absl.testing import absltest +import brainpy.math as bm from absl.testing import parameterized - +from absl.testing import absltest import brainpy as bp -import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 5748edd8..34f8f5cd 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import absltest from absl.testing import parameterized +from absl.testing import absltest import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index b8884f32..18d9d9dc 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -1,120 +1,127 @@ -# -*- coding: utf-8 -*- - - -import numpy as np -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - - -class Test_STDP(parameterized.TestCase): - - @parameterized.product( - comm_method=['dense', 'csr', 'masked_linear', 'all2all', 'one2one'], - delay=[None, 0., 2.], - syn_model=['exp', 'dual_exp', 'ampa'], - out_model=['cuba', 'coba', 'mg'] - ) - def test_STDP(self, comm_method, delay, syn_model, out_model): - bm.random.seed() - - class STDPNet(bp.DynamicalSystem): - def __init__(self, num_pre, num_post): - super().__init__() - self.pre = bp.dyn.LifRef(num_pre) - self.post = bp.dyn.LifRef(num_post) - - if comm_method == 'all2all': - comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'csr': - if syn_model == 'exp': - comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - else: - comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'masked_linear': - comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - weight=bp.init.Uniform(0., 0.1)) - elif comm_method == 'dense': - comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) - elif comm_method == 'one2one': - comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) - else: - raise ValueError - - if syn_model == 'exp': - syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) - elif syn_model == 'dual_exp': - syn = bp.dyn.DualExpon.desc(self.post.varshape) - elif syn_model == 'dual_exp_v2': - syn = bp.dyn.DualExponV2.desc(self.post.varshape) - elif syn_model == 'ampa': - syn = bp.dyn.AMPA.desc(self.post.varshape) - else: - raise ValueError - - if out_model == 'cuba': - out = bp.dyn.CUBA.desc() - elif out_model == 'coba': - out = bp.dyn.COBA.desc(E=0.) - elif out_model == 'mg': - out = bp.dyn.MgBlock.desc(E=0.) - else: - raise ValueError - - self.syn = bp.dyn.STDP_Song2000( - pre=self.pre, - delay=delay, - comm=comm, - syn=syn, - out=out, - post=self.post, - tau_s=16.8, - tau_t=33.7, - A1=0.96, - A2=0.53, - W_min=0., - W_max=1. - ) - - def update(self, I_pre, I_post): - self.syn() - self.pre(I_pre) - self.post(I_post) - conductance = self.syn.refs['syn'].g - Apre = self.syn.refs['pre_trace'].g - Apost = self.syn.refs['post_trace'].g - current = self.post.sum_current_inputs(self.post.V) - if comm_method == 'dense': - w = self.syn.comm.W.flatten() - else: - w = self.syn.comm.weight.flatten() - return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w - - duration = 300. - I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) - I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], - [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) - - net = STDPNet(1, 1) - - def run(i, I_pre, I_post): - pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) - return pre_spike, post_spike, g, Apre, Apost, current, W - - indices = np.arange(int(duration / bm.dt)) - pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) - - # import matplotlib.pyplot as plt - # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) - # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) - # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) - # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) - # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) - # plt.show() - - bm.clear_buffer_memory() - +# -*- coding: utf-8 -*- + +import numpy as np +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +bm.set_platform('cpu') + + +class Test_STDP(parameterized.TestCase): + + @parameterized.product( + comm_method=['csr', 'dense', 'masked_linear', 'all2all', 'one2one'], + delay=[None, 0., 2.], + syn_model=['exp', 'dual_exp', 'ampa'], + out_model=['cuba', 'coba', 'mg'] + ) + def test_STDP(self, comm_method, delay, syn_model, out_model): + bm.random.seed() + + class STDPNet(bp.DynamicalSystem): + def __init__(self, num_pre, num_post): + super().__init__() + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) + + if comm_method == 'all2all': + comm = bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'csr': + if syn_model == 'exp': + comm = bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + else: + comm = bp.dnn.CSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'masked_linear': + comm = bp.dnn.MaskedLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), + weight=bp.init.Uniform(0., 0.1)) + elif comm_method == 'dense': + comm = bp.dnn.Dense(self.pre.num, self.post.num, W_initializer=bp.init.Uniform(.1, 0.1)) + elif comm_method == 'one2one': + comm = bp.dnn.OneToOne(self.pre.num, weight=bp.init.Uniform(.1, 0.1)) + else: + raise ValueError + + if syn_model == 'exp': + syn = bp.dyn.Expon.desc(self.post.varshape, tau=5.) + elif syn_model == 'dual_exp': + syn = bp.dyn.DualExpon.desc(self.post.varshape) + elif syn_model == 'dual_exp_v2': + syn = bp.dyn.DualExponV2.desc(self.post.varshape) + elif syn_model == 'ampa': + syn = bp.dyn.AMPA.desc(self.post.varshape) + else: + raise ValueError + + if out_model == 'cuba': + out = bp.dyn.CUBA.desc() + elif out_model == 'coba': + out = bp.dyn.COBA.desc(E=0.) + elif out_model == 'mg': + out = bp.dyn.MgBlock.desc(E=0.) + else: + raise ValueError + + self.syn = bp.dyn.STDP_Song2000( + pre=self.pre, + delay=delay, + comm=comm, + syn=syn, + out=out, + post=self.post, + tau_s=16.8, + tau_t=33.7, + A1=0.96, + A2=0.53, + W_min=0., + W_max=1. + ) + + def update(self, I_pre, I_post): + self.syn() + self.pre(I_pre) + self.post(I_post) + conductance = self.syn.refs['syn'].g + Apre = self.syn.refs['pre_trace'].g + Apost = self.syn.refs['post_trace'].g + current = self.post.sum_current_inputs(self.post.V) + if comm_method == 'dense': + w = self.syn.comm.W.flatten() + else: + w = self.syn.comm.weight.flatten() + return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, w + + duration = 300. + I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, + duration - 255]) + I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], + [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, + duration - 250]) + + net = STDPNet(1, 1) + + def run(i, I_pre, I_post): + pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) + return pre_spike, post_spike, g, Apre, Apost, current, W + + indices = np.arange(int(duration / bm.dt)) + pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) + + # import matplotlib.pyplot as plt + # fig, gs = bp.visualize.get_figure(4, 1, 3, 10) + # bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0])) + # bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0])) + # bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0])) + # bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0])) + # plt.show() + + bm.clear_buffer_memory() diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index 90500a26..eec2c945 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -1,439 +1,444 @@ -import matplotlib.pyplot as plt -import numpy as np - -import brainpy as bp -import brainpy.math as bm - -neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - - -def test_ProjAlignPreMg1(): - class EICOBA_PreAlign(bp.DynamicalSystem): - def __init__(self, scale=1., inp=20., delay=None): - super().__init__() - - self.inp = inp - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.I, - ) - self.E2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.E, - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - out=bp.dyn.COBA(E=0.), - post=self.E, - ) - self.I2E = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPreSDMg( - pre=self.I, - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - delay=delay, - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PreAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PreAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPostMg2(): - class EICOBA_PostAlign(bp.DynamicalSystem): - def __init__(self, scale, inp=20., ltc=True, delay=None): - super().__init__() - self.inp = inp - - if ltc: - self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) - else: - self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) - self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) - - prob = 80 / (4000 * scale) - - self.E2E = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.E, - ) - self.E2I = bp.dyn.FullProjAlignPostMg( - pre=self.E, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), - out=bp.dyn.COBA.desc(E=0.), - post=self.I, - ) - self.I2E = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), - syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.E, - ) - self.I2I = bp.dyn.FullProjAlignPostMg( - pre=self.I, - delay=delay, - comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), - syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), - out=bp.dyn.COBA.desc(E=-80.), - post=self.I, - ) - - def update(self): - self.E2I() - self.I2I() - self.I2E() - self.E2E() - self.E(self.inp) - self.I(self.inp) - return self.E.spike.value - - net = EICOBA_PostAlign(0.5) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, delay=1.) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - net = EICOBA_PostAlign(0.5, ltc=False) - indices = np.arange(400) - spks = bm.for_loop(net.step_run, indices) - bp.visualize.raster_plot(indices * bm.dt, spks, show=True) - - plt.close() - bm.clear_buffer_memory() - - -def test_ProjAlignPost1(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.num_exc = int(3200 * scale) - self.num_inh = num - self.num_exc - prob = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), - syn=bp.dyn.Expon(size=num, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), - syn=bp.dyn.Expon(size=num, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(spk[:self.num_exc]) - self.I(spk[self.num_exc:]) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet(0.5) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPost2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale, delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (ne + ni) - - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ne, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), - syn=bp.dyn.Expon(size=ni, tau=5.), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ne, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, - delay=delay, - comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), - syn=bp.dyn.Expon(size=ni, tau=10.), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(0.5, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(0.5, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_VanillaProj(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=0.5): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N) - self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N) - - def update(self, input): - spk = self.delay.at('I') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg1_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - delay=delay, - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_ProjAlignPreMg2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1., delay=None): - super().__init__() - ne, ni = int(3200 * scale), int(800 * scale) - p = 80 / (4000 * scale) - self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.)) - self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.E) - self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, - delay=delay, - syn=bp.dyn.Expon.desc(size=ne, tau=5.), - comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.I) - self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.E) - self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, - delay=delay, - syn=bp.dyn.Expon.desc(size=ni, tau=10.), - comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.I) - - def update(self, inp): - self.E2E() - self.E2I() - self.I2E() - self.I2I() - self.E(inp) - self.I(inp) - return self.E.spike - - model = EINet(scale=0.2, delay=None) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - model = EINet(scale=0.2, delay=1.) - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) - bp.visualize.raster_plot(indices, spks, show=True) - - bm.clear_buffer_memory() - plt.close() - - -def test_vanalla_proj_v2(): - class EINet(bp.DynSysGroup): - def __init__(self, scale=1.): - super().__init__() - num = int(4000 * scale) - self.ne = int(3200 * scale) - self.ni = num - self.ne - p = 80 / num - - self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) - self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) - self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) - self.E = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), - out=bp.dyn.COBA(E=0.), - post=self.N - ) - self.I = bp.dyn.VanillaProj( - comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), - out=bp.dyn.COBA(E=-80.), - post=self.N - ) - - def update(self, input): - spk = self.delay.at('delay') - self.E(self.syn1(spk[:self.ne])) - self.I(self.syn2(spk[self.ne:])) - self.delay(self.N(input)) - return self.N.spike.value - - model = EINet() - indices = bm.arange(400) - spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) - bp.visualize.raster_plot(indices, spks, show=True) - plt.close() - bm.clear_buffer_memory() - +import pytest +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + + +def test_ProjAlignPreMg1(): + class EICOBA_PreAlign(bp.DynamicalSystem): + def __init__(self, scale=1., inp=20., delay=None): + super().__init__() + + self.inp = inp + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.I, + ) + self.E2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.E, + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + out=bp.dyn.COBA(E=0.), + post=self.E, + ) + self.I2E = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPreSDMg( + pre=self.I, + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + delay=delay, + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PreAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PreAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPostMg2(): + class EICOBA_PostAlign(bp.DynamicalSystem): + def __init__(self, scale, inp=20., ltc=True, delay=None): + super().__init__() + self.inp = inp + + if ltc: + self.E = bp.dyn.LifRefLTC(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRefLTC(int(800 * scale), **neu_pars) + else: + self.E = bp.dyn.LifRef(int(3200 * scale), **neu_pars) + self.I = bp.dyn.LifRef(int(800 * scale), **neu_pars) + + prob = 80 / (4000 * scale) + + self.E2E = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.E, + ) + self.E2I = bp.dyn.FullProjAlignPostMg( + pre=self.E, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.), + out=bp.dyn.COBA.desc(E=0.), + post=self.I, + ) + self.I2E = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7), + syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.E, + ) + self.I2I = bp.dyn.FullProjAlignPostMg( + pre=self.I, + delay=delay, + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7), + syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.), + out=bp.dyn.COBA.desc(E=-80.), + post=self.I, + ) + + def update(self): + self.E2I() + self.I2I() + self.I2E() + self.E2E() + self.E(self.inp) + self.I(self.inp) + return self.E.spike.value + + net = EICOBA_PostAlign(0.5) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, delay=1.) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + net = EICOBA_PostAlign(0.5, ltc=False) + indices = np.arange(400) + spks = bm.for_loop(net.step_run, indices) + bp.visualize.raster_plot(indices * bm.dt, spks, show=True) + + plt.close() + bm.clear_buffer_memory() + + +def test_ProjAlignPost1(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.num_exc = int(3200 * scale) + self.num_inh = num - self.num_exc + prob = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_exc, num, prob=prob, weight=0.6), + syn=bp.dyn.Expon(size=num, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.HalfProjAlignPost(comm=bp.dnn.EventJitFPHomoLinear(self.num_inh, num, prob=prob, weight=6.7), + syn=bp.dyn.Expon(size=num, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:self.num_exc]) + self.I(spk[self.num_exc:]) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet(0.5) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPost2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale, delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (ne + ni) + + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ne, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPost(pre=self.E, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6), + syn=bp.dyn.Expon(size=ni, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ne, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPost(pre=self.I, + delay=delay, + comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7), + syn=bp.dyn.Expon(size=ni, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(0.5, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(0.5, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_VanillaProj(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=0.5): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ne, num, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N) + self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(self.ni, num, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N) + + def update(self, input): + spk = self.delay.at('I') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg1_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreSDMg(pre=self.E, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreSDMg(pre=self.I, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + delay=delay, + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_ProjAlignPreMg2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1., delay=None): + super().__init__() + ne, ni = int(3200 * scale), int(800 * scale) + p = 80 / (4000 * scale) + self.E = bp.dyn.LifRefLTC(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.E2E = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.E) + self.E2I = bp.dyn.FullProjAlignPreDSMg(pre=self.E, + delay=delay, + syn=bp.dyn.Expon.desc(size=ne, tau=5.), + comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.I) + self.I2E = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.E) + self.I2I = bp.dyn.FullProjAlignPreDSMg(pre=self.I, + delay=delay, + syn=bp.dyn.Expon.desc(size=ni, tau=10.), + comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.I) + + def update(self, inp): + self.E2E() + self.E2I() + self.I2E() + self.I2I() + self.E(inp) + self.I(inp) + return self.E.spike + + model = EINet(scale=0.2, delay=None) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + model = EINet(scale=0.2, delay=1.) + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices) + bp.visualize.raster_plot(indices, spks, show=True) + + bm.clear_buffer_memory() + plt.close() + + +def test_vanalla_proj_v2(): + class EINet(bp.DynSysGroup): + def __init__(self, scale=1.): + super().__init__() + num = int(4000 * scale) + self.ne = int(3200 * scale) + self.ni = num - self.ne + p = 80 / num + + self.N = bp.dyn.LifRefLTC(num, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 1.)) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.syn1 = bp.dyn.Expon(size=self.ne, tau=5.) + self.syn2 = bp.dyn.Expon(size=self.ni, tau=10.) + self.E = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ne, post=num), weight=0.6), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.VanillaProj( + comm=bp.dnn.CSRLinear(bp.conn.FixedProb(p, pre=self.ni, post=num), weight=6.7), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('delay') + self.E(self.syn1(spk[:self.ne])) + self.I(self.syn2(spk[self.ne:])) + self.delay(self.N(input)) + return self.N.spike.value + + model = EINet() + indices = bm.arange(400) + spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices, progress_bar=True) + bp.visualize.raster_plot(indices, spks, show=True) + plt.close() + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index badb6083..c3936f68 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,126 +1,130 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size, ) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py index 39586809..01a31526 100644 --- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py @@ -1,103 +1,108 @@ -# -*- coding: utf-8 -*- - - -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm - -biological_models = [ - bp.synapses.AMPA, - bp.synapses.GABAa, - bp.synapses.BioNMDA, -] - - -class Test_Biological_Synapse(parameterized.TestCase): - @parameterized.product( - synapse=biological_models, - delay_step=[None, 5, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_all2all_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_one2one_synapse(self, synapse, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - synapse=biological_models, - comp_method=['sparse', 'dense'], - delay_step=[None, 10, 1], - mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], - stp=[None, bp.synplast.STP(), bp.synplast.STD()] - ) - def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(10) - post_neu = bp.neurons.LIF(10) - syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), - comp_method=comp_method, delay_step=delay_step, - stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 10) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + +biological_models = [ + bp.synapses.AMPA, + bp.synapses.GABAa, + bp.synapses.BioNMDA, +] + + +class Test_Biological_Synapse(parameterized.TestCase): + @parameterized.product( + synapse=biological_models, + delay_step=[None, 5, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_all2all_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.All2All(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5), ], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_one2one_synapse(self, synapse, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn = synapse(pre_neu, post_neu, conn=bp.conn.One2One(), delay_step=delay_step, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + synapse=biological_models, + comp_method=['sparse', 'dense'], + delay_step=[None, 10, 1], + mode=[bm.NonBatchingMode(), bm.BatchingMode(5)], + stp=[None, bp.synplast.STP(), bp.synplast.STD()] + ) + def test_sparse_synapse(self, synapse, comp_method, delay_step, mode, stp): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(10) + post_neu = bp.neurons.LIF(10) + syn = synapse(pre_neu, post_neu, conn=bp.conn.FixedProb(0.5), + comp_method=comp_method, delay_step=delay_step, + stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 10) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size,) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 19aca92c..6ebe9dc2 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -7,7 +7,7 @@ __all__ = ['mode', 'membrane_scaling', 'dt', 'bool_', 'int_', 'ti_int', 'float_', 'ti_float', 'complex_'] -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) # Default computation mode. mode = NonBatchingMode() @@ -24,15 +24,19 @@ # '''Default integer data type.''' int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default integer data type in Taichi.''' -ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type.''' float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default float data type in Taichi.''' -ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 - # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 +if ti is not None: + # '''Default integer data type in Taichi.''' + ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 + + # '''Default float data type in Taichi.''' + ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 + +else: + ti_int = None + ti_float = None diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index 676e4286..eb8e27c8 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import broadcast_to, expand_dims, concatenate +from .compat_numpy import vstack, broadcast_to from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,7 +392,6 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: - self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -473,7 +472,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) + self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) else: self.data[:] = value diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1c8b98a3..668f837c 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -18,7 +18,7 @@ from . import defaults from brainpy._src.dependency_check import import_taichi -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ # context manage for environment setting @@ -416,13 +416,16 @@ def set_float(dtype: type): """ if dtype in [jnp.float16, 'float16', 'f16']: defaults.__dict__['float_'] = jnp.float16 - defaults.__dict__['ti_float'] = ti.float16 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float16 elif dtype in [jnp.float32, 'float32', 'f32']: defaults.__dict__['float_'] = jnp.float32 - defaults.__dict__['ti_float'] = ti.float32 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float32 elif dtype in [jnp.float64, 'float64', 'f64']: defaults.__dict__['float_'] = jnp.float64 - defaults.__dict__['ti_float'] = ti.float64 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float64 else: raise NotImplementedError @@ -448,16 +451,20 @@ def set_int(dtype: type): """ if dtype in [jnp.int8, 'int8', 'i8']: defaults.__dict__['int_'] = jnp.int8 - defaults.__dict__['ti_int'] = ti.int8 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int8 elif dtype in [jnp.int16, 'int16', 'i16']: defaults.__dict__['int_'] = jnp.int16 - defaults.__dict__['ti_int'] = ti.int16 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int16 elif dtype in [jnp.int32, 'int32', 'i32']: defaults.__dict__['int_'] = jnp.int32 - defaults.__dict__['ti_int'] = ti.int32 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int32 elif dtype in [jnp.int64, 'int64', 'i64']: defaults.__dict__['int_'] = jnp.int64 - defaults.__dict__['ti_int'] = ti.int64 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int64 else: raise NotImplementedError diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 63112955..bdd3102a 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,4 +1,2 @@ - -from ._info_collection import * from ._csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 6e03be46..6b7f7da0 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,34 +10,25 @@ """ -from functools import partial from typing import Union, Tuple import jax import jax.numpy as jnp -import numba import numpy as np -from jax.core import ShapedArray, Primitive -from jax.interpreters import ad, xla -from jax.lib import xla_client +from jax.interpreters import ad -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) -from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv +from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError __all__ = [ 'csrmv' ] -ti = import_taichi() - +ti = import_taichi(error_if_not_found=False) def csrmv( data: Union[float, jax.Array], @@ -53,577 +44,6 @@ def csrmv( This function supports JAX transformations, including `jit()`, `grad()`, `vmap()` and `pmap()`. - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indices should be a 1D vector with int32 or int64 type.') - if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indptr should be a 1D vector with int32 or int64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # computing - return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) - - -# ---------------------------------------------------------- -# event csr matvec -# ---------------------------------------------------------- - -# operator for `event_csr_matvec` batching rule -# -------- - -def _batch_event_csr_matvec_abstract( - values, indices, indptr, events, *, batch_size, shape, transpose=False -): - return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0])) - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, _ = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - values_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - value = values[values_bi, 0] - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += value - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += values[value_bi, j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, transpose = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - value = values[value_bi, 0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += value - res_val[bi, row_i] = r - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += values[value_bi, j] - res_val[bi, row_i] = r - - -def _batch_event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - inputs = (values, indices, indptr, events) - description = dict(batch_size=batch_size, shape=shape, transpose=transpose) - if transpose: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_transpose_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - else: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _batch_event_csr_matvec_gpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - pass - - -def _batch_event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return event_csr_matvec_batching_p.bind(values_dot, indices, indptr, events, - batch_size=batch_size, shape=shape, transpose=transpose) - - -def _batch_csr_matvec(values, indices, indptr, vectors, *, shape, transpose): - f = jax.vmap(partial(normal_csrmv, shape=shape, transpose=transpose), - in_axes=(0 if values.shape[0] > 1 else None, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if vectors.shape[0] > 1 else None)) - return f(values if values.shape[0] > 1 else values[0], - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - vectors if vectors.shape[0] > 1 else vectors[0]) - - -def _batch_event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return _batch_csr_matvec(values, indices, indptr, events_dot, - shape=shape, transpose=transpose) - - -def _batch_event_csr_matvec_transpose(ct, values, indices, indptr, events, *, - batch_size, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(events): - ct_events = ( - ad.Zero(events.aval) if type(ct) is ad.Zero else - _batch_csr_matvec(ct, indices, indptr, values, - shape=shape, transpose=not transpose) - ) - return values, indices, indptr, ct_events - else: - if values.aval.shape[1] == 1: # scalar - temp = event_csr_matvec_batching_p.bind(jnp.ones((1, 1)), indices, indptr, events, - batch_size=batch_size, shape=shape, - transpose=transpose) - ct_values = jax.vmap(jnp.inner)(ct, temp) - else: # heterogeneous values - if type(ct) is ad.Zero: - ct_values = ad.Zero(values.aval) - else: - - def _f(ct, indices, indptr, events, *, transpose): - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values - - f = jax.vmap(partial(_f, transpose=transpose), - in_axes=(0, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if events.shape[0] > 1 else None)) - ct_values = f(ct, - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - events if events.shape[0] > 1 else events[0]) - return ct_values, indices, indptr, events - - -event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching') -event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract) -event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation -ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values, - None, None, _batch_event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose - - -# operator for `event_csr_matvec` # -# ------------------------------- # - - -def _event_csr_matvec_abstract(values, indices, indptr, events, *, shape, transpose=False): - return ShapedArray(dtype=values.dtype, shape=(shape[1] if transpose else shape[0],)) - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values - res_val[row_i] = r - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values - res_val[row_i] = r - - -def _event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, shape, transpose): - inputs = (values, indices, indptr, events) - event_type = c.get_shape(events) - description = dict(shape=shape, transpose=transpose) - if transpose: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_transpose_numba_imp1_bool - else: - imp = _event_csr_matvec_transpose_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_numba_imp1_bool - else: - imp = _event_csr_matvec_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _event_csr_matvec_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_csr_matvec_p.name) - - # shape checking - data_shape = c.get_shape(data) - indices_shape = c.get_shape(indices) - indptr_shape = c.get_shape(indptr) - vec_shape = c.get_shape(vector) - if data_shape.element_type() == jnp.float32: - ftype = b'_float' - elif data_shape.element_type() == jnp.float64: - ftype = b'_double' - else: - raise ValueError - assert indices_shape.element_type() == indptr_shape.element_type() - if indices_shape.element_type() == jnp.int32: - itype = b'_int' - elif indices_shape.element_type() == jnp.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'_homo' if data_shape.dimensions() == (1,) else b'_heter' - tran_type = b'_transpose' if transpose else b'' - if vec_shape.element_type() == jnp.bool_: - vec_type = b'_bool' - else: - assert vec_shape.element_type() == data_shape.element_type() - vec_type = b'' - - # opaque - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - - # call - return xla_client.ops.CustomCallWithLayout( - c, - b'event_csrmv' + data_name + ftype + itype + vec_type + tran_type, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), - (shape[1] if transpose else shape[0],), - (0,)), - opaque=opaque, - ) - - -def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): - batch_size = 0 - args_processed = [] - for arg, axis in zip(args, axes): - if axis is None: - arg = jnp.expand_dims(jnp.atleast_1d(arg), 0) - else: - batch_size = arg.shape[axis] - if axis > 0: - arg = jnp.moveaxis(arg, axis, 0) - args_processed.append(arg) - - r = event_csr_matvec_batching_p.bind(*args_processed, - batch_size=batch_size, - shape=shape, - transpose=transpose) - return r, 0 - - -def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv(values, indices, indptr, ct, shape=shape, transpose=not transpose) - return values, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events) - else: - if type(ct) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) - ct_values = jnp.inner(ct, ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values, indices, indptr, events - - -event_csr_matvec_p = Primitive('event_csr_matvec') -event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract) -event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation -# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, - _event_csr_matvec_jvp_events_brainpylib) -ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib -register_general_batching(event_csr_matvec_p) - - -# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - Parameters ---------- data: ndarray, float @@ -691,298 +111,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - def raw_csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -992,6 +120,9 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False ): + if ti is None: + raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') + if transpose: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -1025,65 +156,361 @@ def raw_csrmv_taichi( shape=shape) -def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) +if ti is not None: + + # ------------- + # CPU operators + # ------------- + + # 1. The benchmarking shows that the performance of the following transpose + # kernels is maximized when using serialized mode + # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable + # arguments, we have to define each kernel separately when the + # non-differentiable/non-jittable arguments are different. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + # 1. GPU kernels are different from the CPU ones, since the GPU kernels need + # to use warp-level parallelism to achieve the best performance. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + # TODO + # It is important to note that the following warp-based kernels + # should be improved, since the atomic_add for each thread is not + # very efficient. Instead, the warp-level reduction primitive + # should be used. + # see ``warp_reduce_sum()`` function in tifunc.py. + # However, currently Taichi does not support general warp-level primitives. + + @ti.kernel + def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive -def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) + def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + + def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) + return prim -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) + # transpose bool homo + _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + # transpose homo + _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, + _event_csr_matvec_transpose_homo_gpu) -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + # not transpose bool homo + _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, + _event_csr_matvec_bool_homo_gpu) -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + # not transpose homo + _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, + _event_csr_matvec_homo_gpu) -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) + # transpose bool heter + _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) + # transpose heter + _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + # not transpose bool heter + _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, + _event_csr_matvec_bool_heter_gpu) -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + # not transpose heter + _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, + _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py deleted file mode 100644 index 7bb043e3..00000000 --- a/brainpy/_src/math/event/_info_collection.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple, Union - -import jax -import numba -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client - -from brainpy._src.dependency_check import import_brainpylib_gpu_ops -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy.errors import GPUOperatorNotFound - -ti = import_taichi() - -__all__ = [ - 'info' -] - - -def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]: - """Collect event information, including event indices, and event number. - - This function supports JAX transformations, including `jit()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - events: jax.Array - The events. - - Returns - ------- - res: tuple - A tuple with two elements, denoting the event indices and the event number. - """ - events = as_jax(events) - if events.ndim != 1: - raise TypeError('Only support 1D boolean vector.') - return event_info_p(events) - - -def _batch_event_info_abstract(events): - assert events.ndim == 2 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(events.shape[0],)) - return event_ids, event_num - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -@ti.kernel -def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2), - event_ids: ti.types.ndarray(ndim=2), - event_num: ti.types.ndarray(ndim=1)): - for i, j in ti.grouped(ti.ndrange(event_ids.shape)): - event_ids[i, j] = -1 - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -def _batch_event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - shape = arg.shape - arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2])) - event_ids, event_num = batch_event_info_p(arg) - return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])), - (0, 0)) - - -def _event_info_gpu_translation(c, events): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_info_p.name) - - e_shape = c.get_shape(events).dimensions() - e_type = c.get_shape(events).element_type() - if len(e_shape) == 1: - event_size = e_shape[0] - batch_size = 1 - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (event_size,), - (0,)) - else: - batch_size, event_size = e_shape - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size, event_size), - (1, 0)) - event_num_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size,), - (0,)) - opaque = gpu_ops.build_nonzero_descriptor(event_size, batch_size) - - if e_type == jnp.bool_: - type_name = b'_bool' - elif e_type == jnp.int32: - type_name = b'_int' - elif e_type == jnp.int64: - type_name = b'_long' - elif e_type == jnp.float32: - type_name = b'_float' - elif e_type == jnp.float64: - type_name = b'_double' - else: - raise ValueError - - return xla_client.ops.CustomCallWithLayout( - c, - b'nonzero' + type_name, - operands=(events,), - operand_shapes_with_layout=(c.get_shape(events),), - shape_with_layout=xla_client.Shape.tuple_shape((event_ids_shape, event_num_shape)), - opaque=opaque, - ) - - -batch_event_info_p = XLACustomOp( - name='batched_event_info', - cpu_kernel=_batch_event_info_taichi, - gpu_kernel=_batch_event_info_taichi, - outs=_batch_event_info_abstract, -) -batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule) - - -def _event_info_abstract(events, **kwargs): - assert events.ndim == 1 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(1,)) - return event_ids, event_num - - -# TODO: first parallel evaluate the sub-sections, then serially event the sub-results. -@numba.njit(fastmath=True) -def _event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -@ti.kernel -def _event_info_taichi(events: ti.types.ndarray(ndim=1), - event_ids: ti.types.ndarray(ndim=1), - event_num: ti.types.ndarray(ndim=1)): - for i in range(event_ids.shape[0]): - event_ids[i] = -1 - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -def _event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - return (batch_event_info_p(arg), (0, 0)) - - -event_info_p = XLACustomOp( - name='event_info', - cpu_kernel=_event_info_taichi, - gpu_kernel=_event_info_taichi, - outs=_event_info_abstract, - # gpu_func_translation=_event_info_gpu_translation, -) -event_info_p.def_batching_rule(_event_info_batching_rule) diff --git a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py b/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py deleted file mode 100644 index 74cc6b7f..00000000 --- a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py +++ /dev/null @@ -1,275 +0,0 @@ -from time import time - -from jax import jit, vmap, numpy as jnp - -import brainpy.math as bm - - -def compare_argsort_and_sum(platform='cpu'): - """ - CPU - --- - - shape = (100, 10000) - brainpylib 0.1872694492340088 s - JAX argsort + sum 5.297466516494751 s - - shape = (100, 100000) - brainpylib 2.333505153656006 s - JAX argsort + sum 65.20281910896301 s - - shape = (1000, 10000) - brainpylib 2.0739688873291016 s - JAX argsort + sum 53.70602822303772 s - - shape = (10000, 1000) - brainpylib 1.7262670993804932 s - JAX argsort + sum 43.92174816131592 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.14670848846435547 s - JAX argsort + sum 1.001936435699463 s - - shape = (100, 1000000) - brainpylib 0.27660632133483887 s - JAX argsort + sum 16.390073776245117 s - - shape = (1000, 100000) - brainpylib 0.2619345188140869 s - JAX argsort + sum 9.715844869613647 s - - shape = (1000, 500000) - brainpylib 1.201209306716919 s - JAX argsort + sum 71.19761657714844 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: (jnp.argsort(events), jnp.sum(events)))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids2, event_num2 = jax_event_info(events) - assert jnp.allclose(event_num1, event_num2) - event_ids1.block_until_ready() - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, b = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort + sum {time() - t0} s') - - print() - - -def compare_argsort(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.19738531112670898 s - JAX argsort 5.301469087600708 s - - shape = (100, 100000) - brainpylib 2.3321938514709473 s - JAX argsort 65.13460850715637 s - - shape = (1000, 10000) - brainpylib 2.0956876277923584 s - JAX argsort 53.863110065460205 s - - shape = (10000, 1000) - brainpylib 1.7127799987792969 s - JAX argsort 44.05547475814819 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.1415419578552246 s - JAX argsort 0.9982438087463379 s - - shape = (100, 1000000) - brainpylib 0.3224947452545166 s - JAX argsort 16.504750967025757 s - - shape = (1000, 100000) - brainpylib 0.2781648635864258 s - JAX argsort 9.691488981246948 s - - shape = (1000, 500000) - brainpylib 1.2167487144470215 s - JAX argsort 71.68716263771057 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.argsort(events))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2 = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort {time() - t0} s') - - print() - - -def compare_where(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.20480966567993164 s - JAX where 0.7068588733673096 s - - shape = (100, 100000) - brainpylib 2.3373026847839355 s - JAX where 5.862265348434448 s - - shape = (1000, 10000) - brainpylib 2.105764865875244 s - JAX where 5.914586067199707 s - - shape = (10000, 1000) - brainpylib 1.724682331085205 s - JAX where 5.718563795089722 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.15492558479309082 s - JAX where 0.3146538734436035 s - - shape = (100, 1000000) - brainpylib 0.3290700912475586 s - JAX where 1.7064015865325928 s - - shape = (1000, 100000) - brainpylib 0.2895216941833496 s - JAX where 1.6910102367401123 s - - shape = (1000, 500000) - brainpylib 1.173649787902832 s - JAX where 7.868000268936157 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.where(events, size=events.shape[0]))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2, = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX where {time() - t0} s') - - print() - - -if __name__ == '__main__': - # compare_argsort_and_sum('cpu') - # compare_argsort_and_sum('gpu') - # compare_argsort('cpu') - compare_argsort('gpu') - # compare_where('cpu') - # compare_where('gpu') diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index e0f38490..67e09d0a 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -4,11 +4,18 @@ from functools import partial import jax +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py deleted file mode 100644 index 31a6527a..00000000 --- a/brainpy/_src/math/event/tests/test_event_csrmv_old.py +++ /dev/null @@ -1,324 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -import platform - -import pytest -pytest.skip('Old implementation.', allow_module_level=True) - -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') -taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -class Test_event_csr_matvec(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_csr_matvec, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo(self, shape, transpose, homo_data): - print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - heter_data = bm.ones(indices.shape) * homo_data - - r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = brainpylib_csr_matvec(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = brainpylib_csr_matvec(homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r3)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r4)) - - r5 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r5)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}, homo_data={homo_data}', - transpose=transpose, - shape=shape, - homo_data=homo_data, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo_vmap(self, shape, transpose, homo_data): - print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), homo_data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose, - method='cusparse')) - vmap_data1 = bm.as_jax([homo_data] * 10) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose},shape={shape},homo_data={homo_data}', - homo_data=homo_data, - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - for homo_data in [-1., 0., 1.] - ) - def test_homo_grad(self, shape, transpose, homo_data): - print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - r3 = jax.grad(sum_op(lambda a: (events @ (dense_conn * a) if transpose else - ((dense_conn * a) @ events))))(homo_data) - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: (e @ (dense_conn * homo_data) if transpose else - ((dense_conn * homo_data) @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f'transpose={transpose}, shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (10000, 2)] - ) - def test_heter(self, shape, transpose): - print(f'test_heter: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) - r2 = partial(bm.sparse.csrmv, method='cusparse')(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (events @ dense) if transpose else (dense @ events) - self.assertTrue(bm.allclose(r1, r3)) - - r4 = brainpylib_csr_matvec(heter_data, indices, indptr, events.astype(float), - shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=f"transpose={transpose}, shape={shape}", - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_vmap(self, shape, transpose): - print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - # vmap 'data' - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events, - shape=shape, transpose=transpose)) - f2 = jax.vmap( - partial(partial(bm.sparse.csrmv, method='cusparse'), indices=indices, indptr=indptr, vector=events.astype(float), - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) - - # vmap 'events' - data = bm.as_jax(rng.random(indices.shape)) - f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr, - shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(partial(bm.sparse.csrmv, method='cusparse'), data, indices, indptr, - shape=shape, transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 - self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data.astype(float)))) - - # vmap 'data' and 'events' - f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: partial(bm.sparse.csrmv, method='cusparse')(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) - vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) - vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 - self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), - f6(vmap_data1, vmap_data2.astype(float)))) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'transpose={transpose},shape={shape}', - shape=shape, - transpose=transpose, - ) - for transpose in [True, False] - for shape in [(100, 200), - (200, 200), - (200, 100), - (10, 1000), - (2, 10000), - (1000, 10), - (100000, 2)] - ) - def test_heter_grad(self, shape, transpose): - print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - - rng = bm.random.RandomState() - indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) - - # grad 'data' - data = bm.as_jax(rng.random(indices.shape)) - r1 = jax.grad(sum_op(brainpylib_csr_matvec))( - data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')))( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense_data = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape) - r3 = jax.grad(sum_op(lambda a: ((events @ a) if transpose else - (a @ events))))(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r3 = r3[rows, cols] - self.assertTrue(bm.allclose(r1, r3)) - - # grad 'events' - r4 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r5 = jax.grad(sum_op(partial(bm.sparse.csrmv, method='cusparse')), argnums=3)( - data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda e: ((e @ dense_data) if transpose else - (dense_data @ e))))(events.astype(float)) - self.assertTrue(bm.allclose(r4, r5)) - self.assertTrue(bm.allclose(r4, r6)) - - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/event/tests/test_info.py b/brainpy/_src/math/event/tests/test_info.py deleted file mode 100644 index c326b0f7..00000000 --- a/brainpy/_src/math/event/tests/test_info.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -import unittest - -import brainpy.math as bm -from jax import vmap - -import pytest - - -class Test_event_info(unittest.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_info, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - def _base_test(self, length): - print(f'{self._base_test.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(length)) < 0.1 - event_ids, event_num = bm.event.info(events) - self.assertTrue(jnp.allclose(jnp.sum(events, keepdims=True), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap(self, length): - print(f'{self._base_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(bm.event.info)(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap_vmap(self, length): - print(f'{self._base_vmap_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(vmap(bm.event.info))(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def test(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - - diff --git a/brainpy/_src/math/event/tests/test_info_gpu.py b/brainpy/_src/math/event/tests/test_info_gpu.py deleted file mode 100644 index 55bdd15c..00000000 --- a/brainpy/_src/math/event/tests/test_info_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_info - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_info_GPU(test_info.Test_event_info): - def __init__(self, *args, **kwargs): - super(Test_event_info_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py deleted file mode 100644 index 6c71b4b0..00000000 --- a/brainpy/_src/math/index_tricks.py +++ /dev/null @@ -1,305 +0,0 @@ -# -*- coding: utf-8 -*- - -import abc - -from jax import core -from .compat_numpy import arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose -import numpy as np - -__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] - - -def _make_1d_grid_from_slice(s: slice, op_name: str): - start = core.concrete_or_error(None, s.start, - f"slice start of jnp.{op_name}") or 0 - stop = core.concrete_or_error(None, s.stop, - f"slice stop of jnp.{op_name}") - step = core.concrete_or_error(None, s.step, - f"slice step of jnp.{op_name}") or 1 - if np.iscomplex(step): - newobj = linspace(start, stop, int(abs(step))) - else: - newobj = arange(start, stop, step) - - return newobj - - -class _IndexGrid(abc.ABC): - """Creates multi-dimensional grids of indices.""" - sparse: bool - op_name: str - - def __getitem__(self, key): - if isinstance(key, slice): - return _make_1d_grid_from_slice(key, op_name=self.op_name) - output = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) - output = meshgrid(*output, indexing='ij', sparse=self.sparse) - return output if self.sparse else stack(output, 0) - - -class _Mgrid(_IndexGrid): - """Return dense multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``. - - See Also: - jnp.ogrid: open/sparse version of jnp.mgrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> import brainpy.math as bm - >>> bm.mgrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.mgrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create broadcasted grids of indices: - - >>> bm.mgrid[:2, :3] - DeviceArray([[[0, 0, 0], - [1, 1, 1]], - [[0, 1, 2], - [0, 1, 2]]], dtype=int32) - """ - sparse = False - op_name = "mgrid" - - -mgrid = _Mgrid() - - -class _Ogrid(_IndexGrid): - """Return open multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``. - - See Also: - jnp.mgrid: dense version of jnp.ogrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> bm.ogrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.ogrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create sparse grids of indices: - - >>> bm.ogrid[:2, :3] - [DeviceArray([[0], - [1]], dtype=int32), - DeviceArray([[0, 1, 2]], dtype=int32)] - """ - sparse = True - op_name = "ogrid" - - -ogrid = _Ogrid() - - -class _AxisConcat(abc.ABC): - """Concatenates slices, scalars and array-like objects along a given axis.""" - axis: int - ndmin: int - trans1d: int - op_name: str - - def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - - params = [self.axis, self.ndmin, self.trans1d, -1] - - if isinstance(key[0], str): - # split off the directive - directive, *key = key # pytype: disable=bad-unpacking - # check two special cases: matrix directives - if directive == "r": - params[-1] = 0 - elif directive == "c": - params[-1] = 1 - else: - vec = directive.split(",") - k = len(vec) - if k < 4: - vec += params[k:] - else: - # ignore everything after the first three comma-separated ints - vec = vec[:3] + params[-1] - try: - params = list(map(int, vec)) - except ValueError as err: - raise ValueError( - "could not understand directive {!r}".format(directive) - ) from err - - axis, ndmin, trans1d, matrix = params - - output = [] - for item in key: - if isinstance(item, slice): - newobj = _make_1d_grid_from_slice(item, op_name=self.op_name) - elif isinstance(item, str): - raise ValueError("string directive must be placed at the beginning") - else: - newobj = item - - newobj = array(newobj, copy=False, ndmin=ndmin) - - if trans1d != -1 and ndmin - np.ndim(item) > 0: - shape_obj = list(range(ndmin)) - # Calculate number of left shifts, with overflow protection by mod - num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin - shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) - - newobj = transpose(newobj, shape_obj) - - output.append(newobj) - - res = concatenate(tuple(output), axis=axis) - - if matrix != -1 and res.ndim == 1: - # insert 2nd dim at axis 0 or 1 - res = expand_dims(res, matrix) - - return res - - def __len__(self): - return 0 - - -class RClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the first axis. - - LAX-backend implementation of :obj:`numpy.r_`. - - See Also: - ``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis. - - Examples: - Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects: - - >>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])] - DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32) - - An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, - which includes the right endpoint: - - >>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])] - DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005, - 0.6 , 1. , 0. , 1. , - 2. , 3. ], dtype=float32) - - Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to - specify concatenation axis, minimum number of dimensions, and the position of the - upgraded array's original dimensions in the resulting array's shape tuple: - - >>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - >>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Negative values for ``trans1d`` offset the last axis towards the start - of the shape tuple: - - >>> bm.r_['0,2,-2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with an extra row or column axis, respectively: - - >>> bm.r_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32) - - >>> bm.r_['c',[1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"`` - give the same result. - """ - axis = 0 - ndmin = 1 - trans1d = -1 - op_name = "r_" - - -r_ = RClass() - - -class CClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the last axis. - - LAX-backend implementation of :obj:`numpy.c_`. - - See Also: - ``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis. - - Examples: - - >>> a = bm.arange(6).reshape((2,3)) - >>> bm.c_[a,a] - DeviceArray([[0, 1, 2, 0, 1, 2], - [3, 4, 5, 3, 4, 5]], dtype=int32) - - Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify - concatenation axis, minimum number of dimensions, and the position of the upgraded array's - original dimensions in the resulting array's shape tuple: - - >>> bm.c_['0,2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - >>> bm.c_['0,2,-1', [1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with inputs stacked along the last axis: - - >>> jnp.c_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 4], - [2, 5], - [3, 6]], dtype=int32) - """ - axis = -1 - ndmin = 2 - trans1d = 0 - op_name = "c_" - - -c_ = CClass() - -s_ = np.s_ - -index_exp = np.index_exp diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index a79cdc98..6f7cddf6 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,3 +1,2 @@ - -from ._matvec import * -from ._event_matvec import * \ No newline at end of file +from ._matvec import * +from ._event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 3671755a..976b72b9 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -1,21 +1,14 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, - mv_prob_uniform_p, - mv_prob_normal_p, - mv_prob_homo, +from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, mv_prob_normal, _general_checking, @@ -27,11 +20,10 @@ _mv_prob_normal_transpose, _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'event_mv_prob_homo', @@ -50,746 +42,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) -ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) weight = jnp.atleast_1d(as_jax(weight)) @@ -799,11 +54,16 @@ def event_mv_prob_homo_taichi( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return raw_event_mv_prob_homo(events, weight, conn_len, seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ -def event_mv_prob_uniform_taichi( +def event_mv_prob_uniform( events: jax.Array, w_low: float, w_high: float, @@ -814,56 +74,9 @@ def event_mv_prob_uniform_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) if isinstance(w_high, float): w_high = as_jax(w_high) @@ -879,7 +92,10 @@ def event_mv_prob_uniform_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -def event_mv_prob_normal_taichi( +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( events: jax.Array, w_mu: float, w_sigma: float, @@ -890,56 +106,9 @@ def event_mv_prob_normal_taichi( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) @@ -955,1034 +124,1036 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 +event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ + +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] = r + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _reverse(shape): - return shape[::-1] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _reverse(shape): + return shape[::-1] + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + i_col += inc + out[i_row] += r # TODO: warp-level reduction -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + + def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu + ) -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu + ) -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu + ) -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu + ) -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_normal_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: + prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu + ) -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu + ) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 0caa9c99..00e5778f 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -1,24 +1,20 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional, Union import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'mv_prob_homo', @@ -85,8 +81,22 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_homo_taichi(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(weight, float): + weight = as_jax(weight, dtype=vector.dtype) + weight = jnp.atleast_1d(as_jax(weight)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.asarray(seed, dtype=jnp.uint32) + seed = jnp.atleast_1d(seed) + return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] def mv_prob_uniform( @@ -150,8 +160,22 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + + vector = as_jax(vector) + if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) + if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] def mv_prob_normal( @@ -215,1188 +239,110 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + vector = as_jax(vector) + if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) + if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 + conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) + if seed is None: + with jax.ensure_compile_time_eval(): + seed = np.random.randint(0, int(1e8), 1) + seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) + return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel)[0] -### BRAINYPLIB ### -def mv_prob_homo_brainpylib( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, +def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p + else: + prim = _mv_prob_homo_p - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - )[0] - - -def mv_prob_uniform_brainpylib( +def raw_mv_prob_uniform( vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_brainpylib( +def raw_mv_prob_normal( vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, *, shape: Tuple[int, int], transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def _matvec_prob_homo_abstract( - vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - if transpose: - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({vector.shape[0]},) @ mat {shape}.') - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_homo_cpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - out_type = b'_float' - elif out_dtype == jnp.float64: - out_type = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_homo' + out_type - else: - fn = b'cpu_matvec_atomic_prob_homo' + out_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_homo_gpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_homo_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_homo_v2' + type_name - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, weight, clen, seed = primals - vector_dot, weight_dot, clen_dot, seed_dot = tangents - r = mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(vector_dot) is ad.Zero: - raise ValueError - r_dot = mv_prob_homo_p.bind(vector_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(vector_dot) is ad.Zero: - r_dot = mv_prob_homo_p.bind(vector, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - r_dot = mv_prob_homo_p.bind(vector_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - return r, r_dot - - -def _matvec_prob_homo_transpose( - ct, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - assert type(vector) is ad.UndefinedPrimal - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -mv_prob_homo_p = Primitive('matvec_prob_homo') -mv_prob_homo_p.multiple_results = True -mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract) -mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation -register_general_batching(mv_prob_homo_p) -ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp -ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose - - -def _matvec_prob_uniform_abstract( - vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype == vector.dtype - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_uniform_cpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_uniform' + type_name - else: - fn = b'cpu_matvec_atomic_prob_uniform' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_uniform_gpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_uniform_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_uniform_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_low, w_high, clen, seed = primals - vector_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(vector_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -mv_prob_uniform_p = Primitive('matvec_prob_uniform') -mv_prob_uniform_p.multiple_results = True -mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract) -mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation -register_general_batching(mv_prob_uniform_p) -ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp -ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose - - -def _matvec_prob_normal_abstract( - vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_normal_cpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_normal' + type_name - else: - fn = b'cpu_matvec_atomic_prob_normal' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_normal_gpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - event_shape = c.get_shape(vector) - out_dtype = event_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_normal_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_normal_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed,), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_mu, w_sigma, clen, seed = primals - vector_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(vector_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -mv_prob_normal_p = Primitive('matvec_prob_normal') -mv_prob_normal_p.multiple_results = True -mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract) -mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation -register_general_batching(mv_prob_normal_p) -ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp -ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose - - -### TAICHI ### -def mv_prob_homo_taichi( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Generally, the :math:`M` in ``f(outdim_parallel=True, transpose=False)`` is the same of - the :math:`M^T` used in ``f(outdim_parallel=False, transpose=True)``. - - Similarly, the :math:`M^T` in ``f(outdim_parallel=True, transpose=True)`` is the same - of the :math:`M` used in ``f(outdim_parallel=False, transpose=False)``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_uniform_taichi( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_taichi( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - - -def _reverse(shape): - return shape[::-1] - - -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): +def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): if vector.ndim != 1: raise ValueError('vector should be a 1D vector.') if len(shape) != 2: @@ -1437,190 +383,28 @@ def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, * return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) +def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed + else: + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' def _mv_prob_uniform_transpose( @@ -1641,265 +425,496 @@ def _mv_prob_uniform_transpose( assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p +def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel +): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed else: - prim = _mv_prob_uniform_p + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) +def _reverse(shape): + return shape[::-1] -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v + +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + + @ti.kernel + def _mv_prob_homo_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + + @ti.kernel + def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] = r + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v + def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + + # outdim_parallel = False + _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + + @ti.kernel + def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True + _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu + ) - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p + # outdim_parallel = False + _mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu + ) - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) + @ti.kernel + def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) + + def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu + ) + + # outdim_parallel = False + _mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index b10d55d2..d8e08654 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -4,8 +4,14 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 2e6e406c..8a0ae444 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,8 +4,13 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index ad8a5ccf..f5e09167 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,8 +28,10 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, VariableStack) -from .tools import eval_shape +from .variables import (Variable, + VariableStack, + current_transform_number, + new_transform) __all__ = [ 'grad', # gradient of scalar function @@ -201,21 +203,36 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with VariableStack() as stack: - rets = eval_shape(self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs) + with new_transform(self): + with VariableStack() as stack: + if current_transform_number() > 1: + rets = self._transform( + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) + else: + rets = jax.eval_shape( + self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs + ) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if not stack.is_first_stack(): - return self._return(rets) + # if not the outermost transformation + if current_transform_number(): + return self._return(rets) + else: + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index c52845a0..aaf053ae 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,6 +6,7 @@ """ import numbers +import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional @@ -13,13 +14,14 @@ import jax import numpy as np -from brainpy._src.math.modes import Mode +from brainpy import errors from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) +from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS variable_ = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 3edeb08e..032a0fab 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,12 +21,17 @@ cache_stack ) from .tools import ( - eval_shape, + evaluate_dyn_vars, dynvar_deprecation, node_deprecation, abstract ) -from .variables import (Variable, VariableStack) +from .variables import ( + Variable, + VariableStack, + new_transform, + current_transform_number, +) __all__ = [ 'make_loop', @@ -537,13 +542,15 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit and dyn_vars is None: - with VariableStack() as dyn_vars: - rets = eval_shape(true_fun, *operands, with_stack=True)[1] - _ = eval_shape(false_fun, *operands, with_stack=True) - cache_stack((true_fun, false_fun), dyn_vars) - if not dyn_vars.is_first_stack(): - return rets + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('cond'): + dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((true_fun, false_fun), dyn_vars) + if current_transform_number() > 0: + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -674,16 +681,20 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with VariableStack() as dyn_vars: - rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with new_transform('ifelse'): + with VariableStack() as dyn_vars: + if current_transform_number() > 1: + rets = [branch(*operands) for branch in branches] + else: + rets = [jax.eval_shape(branch, *operands) for branch in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if not dyn_vars.is_first_stack(): + if current_transform_number(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -869,23 +880,28 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - stack = get_stack_cache((body_fun, unroll_kwargs)) + dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if stack is None: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, - remat, reverse, unroll, unroll_kwargs) + if dyn_vars is None: # TODO: better cache mechanism? - with VariableStack() as stack: - rets = eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), stack) # cache - if not stack.is_first_stack(): + with new_transform('for_loop'): + with VariableStack() as dyn_vars: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, + progress_bar, remat, reverse, unroll, + unroll_kwargs) + if current_transform_number() > 1: + rets = transform(operands) + else: + rets = jax.eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache + if current_transform_number(): return rets[1] del rets else: - stack = VariableStack() + dyn_vars = VariableStack() # TODO: cache mechanism? - transform = _get_for_loop_transform(body_fun, stack, bar, + transform = _get_for_loop_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -893,11 +909,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, stack + del dyn_vals, dyn_vars return out_vals @@ -995,21 +1011,26 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - stack = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit and stack is None: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - with VariableStack() as stack: - rets = eval_shape(transform, init, operands) - cache_stack(body_fun, stack) # cache - if not stack.is_first_stack(): - return rets[0][1], rets[1] - del rets - - stack = VariableStack() if stack is None else stack - transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) + dyn_vars = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('scan'): + with VariableStack() as dyn_vars: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + if current_transform_number() > 1: + rets = transform(init, operands) + else: + rets = jax.eval_shape(transform, init, operands) + cache_stack(body_fun, dyn_vars) # cache + if current_transform_number(): + return rets[0][1], rets[1] + del rets + + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in stack.keys(): - stack[key]._value = dyn_vals[key] + for key in dyn_vars.keys(): + dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1108,6 +1129,7 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. + """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1115,16 +1137,18 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - stack = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit and stack is None: - with VariableStack() as stack: - _ = eval_shape(cond_fun, *operands, with_stack=True) - rets = eval_shape(body_fun, *operands, with_stack=True)[1] - cache_stack((body_fun, cond_fun), stack) - if not stack.is_first_stack(): - return rets - stack = VariableStack() if stack is None else stack - dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) - for k, v in stack.items(): + dyn_vars = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit: + if dyn_vars is None: + with new_transform('while_loop'): + dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) + dyn_vars = dyn_vars1 + dyn_vars2 + cache_stack((body_fun, cond_fun), dyn_vars) + if current_transform_number(): + return rets + dyn_vars = VariableStack() if dyn_vars is None else dyn_vars + dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) + for k, v in dyn_vars.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 73eab2f9..7bb36f4e 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,15 +11,23 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax +from jax.sharding import Sharding from brainpy import tools, check -from .base import BrainPyObject, ObjectTransform -from .naming import get_stack_cache, cache_stack from .tools import (dynvar_deprecation, node_deprecation, - eval_shape) -from .variables import (Variable, VariableStack) + evaluate_dyn_vars_with_cache, + evaluate_dyn_vars, + _partial_fun) +from .base import BrainPyObject, ObjectTransform +from .naming import get_stack_cache, cache_stack from ..ndarray import Array +from .variables import (Variable, + VariableStack, + outermost_transform, + transform_stack, + current_transform_number, + new_transform) RandomState = None @@ -143,12 +151,16 @@ def _transform_function(self, variable_data: Dict, *args, **kwargs): return changes, out def _get_transform(self, *args, **kwargs): - with VariableStack() as self._dyn_vars: - rets = eval_shape(self.fun, - *args, - **kwargs, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames) + with new_transform(self): + self._dyn_vars, rets = evaluate_dyn_vars( + self.fun, + *args, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames, + use_eval_shape=current_transform_number() <= 1, + **kwargs + ) + # in_shardings if self._in_shardings is None: in_shardings = None @@ -174,18 +186,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + self._transform_function, + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -195,7 +207,7 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if not self._dyn_vars.is_first_stack(): + if current_transform_number(): return rets # call the transformed function @@ -465,8 +477,15 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - with VariableStack() as stack: - _ = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + + with jax.ensure_compile_time_eval(): + if len(static_argnums) or len(static_argnames): + fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) + else: + args_, kwargs_, fun3 = args, kwargs, fun2 + with VariableStack() as stack: + _ = jax.eval_shape(fun3, *args_, **kwargs_) + del args_, kwargs_ _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1181e003..1c8ca6ef 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -41,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=True): +def clear_name_cache(ignore_warn=False): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -57,7 +57,6 @@ def cache_stack(func, stack): def clear_stack_cache(): - """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py new file mode 100644 index 00000000..1eddce04 --- /dev/null +++ b/brainpy/_src/math/object_transform/parallels.py @@ -0,0 +1,460 @@ +# -*- coding: utf-8 -*- + +""" +The parallel compilation tools for JAX backend. + +1. Vectorize compilation is implemented by the 'vmap()' function +2. Parallel compilation is implemented by the 'pmap()' function + +""" + + +import functools + +import jax +import jax.numpy as jnp +import numpy as np +from jax.interpreters.partial_eval import DynamicJaxprTracer +from jax.interpreters.partial_eval import JaxprTracer +from jax.interpreters.pxla import ShardedDeviceArray + +try: + from jax.errors import UnexpectedTracerError +except ImportError: + from jax.core import UnexpectedTracerError + +from brainpy import errors +from brainpy._src.math.random import RandomState +from brainpy._src.math.ndarray import Array +from brainpy.tools import change_func_name +from .base import BrainPyObject, ArrayCollector + +__all__ = [ + 'vmap', + 'pmap', +] + + +def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, + batch_idx, axis_name, f_name=None): + @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) + def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + out = func(*args, **kwargs) + nonbatched_changes = nonbatched_vars.dict() + batched_changes = batched_vars.dict() + return nonbatched_changes, batched_changes, out + + def call(*args, **kwargs): + n = args[batch_idx[0]].shape[batch_idx[1]] + nonbatched_data = nonbatched_vars.dict() + batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} + try: + out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) + except UnexpectedTracerError as e: + nonbatched_vars.assign(nonbatched_data) + batched_vars.assign(batched_data) + raise errors.JaxTracerError() from e + # for key, v in dyn_changes.items(): + # dyn_vars[key] = reduce_func(v) + # for key, v in rand_changes.items(): + # rand_vars[key] = reduce_func(v) + return out + + return change_func_name(name=f_name, f=call) if f_name else call + + +def vmap(func, dyn_vars=None, batched_vars=None, + in_axes=0, out_axes=0, axis_name=None, + reduce_func=None, auto_infer=False): + """Vectorization compilation for class objects. + + Vectorized compile a function or a module to run in parallel on a single device. + + Examples + -------- + + Parameters + ---------- + func : BrainPyObject, function, callable + The function or the module to compile. + dyn_vars : dict, sequence + batched_vars : dict + in_axes : optional, int, sequence of int + Specify which input array axes to map over. If each positional argument to + ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, + or a tuple of integers and Nones with length equal to the number of + positional arguments to ``obj_or_func``. An integer or ``None`` + indicates which array axis to map over for all arguments (with ``None`` + indicating not to map any axis), and a tuple indicates which axis to map + for each corresponding positional argument. Axis integers must be in the + range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of + dimensions (axes) of the corresponding input array. + + If the positional arguments to ``obj_or_func`` are container types, the + corresponding element of ``in_axes`` can itself be a matching container, + so that distinct array axes can be mapped for different container + elements. ``in_axes`` must be a container tree prefix of the positional + argument tuple passed to ``obj_or_func``. + + At least one positional argument must have ``in_axes`` not None. The sizes + of the mapped input axes for all mapped positional arguments must all be + equal. + + Arguments passed as keywords are always mapped over their leading axis + (i.e. axis index 0). + out_axes : optional, int, tuple/list/dict + Indicate where the mapped axis should appear in the output. All outputs + with a mapped axis must have a non-None ``out_axes`` specification. Axis + integers must be in the range ``[-ndim, ndim)`` for each output array, + where ``ndim`` is the number of dimensions (axes) of the array returned + by the :func:`vmap`-ed function, which is one more than the number of + dimensions (axes) of the corresponding array returned by ``obj_or_func``. + axis_name : optional + + Returns + ------- + obj_or_func : Any + Batched/vectorized version of ``obj_or_func`` with arguments that correspond to + those of ``obj_or_func``, but with extra array axes at positions indicated by + ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but + with extra array axes at positions indicated by ``out_axes``. + + """ + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # dyn_vars = (dyn_vars or func.vars().unique()) + # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() + # for key, val in dyn_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # in axes + # if in_axes is None: + # in_axes = {key: (None, 0) for key in func.steps.keys()} + # elif isinstance(in_axes, int): + # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, (tuple, list)): + # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} + # elif isinstance(in_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in in_axes: + # in_axes = {key: (None, 0, in_axes) for key in keys} + # else: + # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} + # assert isinstance(in_axes, dict) + # + # # batch size index + # batch_idx = {} + # for key, axes in in_axes.items(): + # for i, axis in enumerate(axes[2:]): + # if axis is not None: + # batch_idx[key] = (i, axis) + # break + # else: + # raise ValueError(f'Found no batch axis: {axes}.') + # + # # out axes + # if out_axes is None: + # out_axes = {key: 0 for key in func.steps.keys()} + # elif isinstance(out_axes, int): + # out_axes = {key: out_axes for key in func.steps.keys()} + # elif isinstance(out_axes, (tuple, list)): + # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} + # elif isinstance(out_axes, dict): + # keys = list(func.steps.keys()) + # if keys[0] not in out_axes: + # out_axes = {key: (out_axes, 0, 0) for key in keys} + # else: + # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} + # assert isinstance(out_axes, dict) + # + # # reduce_func + # if reduce_func is None: + # reduce_func = lambda x: x.mean(axis=0) + # + # # vectorized map functions + # for key in func.steps.keys(): + # func.steps[key] = _make_vmap(func=func.steps[key], + # dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # in_axes=in_axes[key], + # out_axes=out_axes[key], + # axis_name=axis_name, + # batch_idx=batch_idx[key], + # reduce_func=reduce_func, + # f_name=key) + # + # return func + + if callable(func): + if auto_infer: + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.vmap(func, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name) + + else: + if isinstance(dyn_vars, Array): + dyn_vars = [dyn_vars] + if isinstance(dyn_vars, (tuple, list)): + dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} + assert isinstance(dyn_vars, dict) + + # dynamical variables + _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + _rand_vars[key] = val + else: + _dyn_vars[key] = val + + # in axes + if in_axes is None: + in_axes = (None, 0) + elif isinstance(in_axes, (int, dict)): + in_axes = (None, 0, in_axes) + elif isinstance(in_axes, (tuple, list)): + in_axes = (None, 0) + tuple(in_axes) + assert isinstance(in_axes, (tuple, list)) + + # batch size index + batch_idx = {} + for key, axes in batch_idx.items(): + for i, axis in enumerate(axes[2:]): + if axis is not None: + batch_idx[key] = (i, axis) + break + else: + raise ValueError(f'Found no batch axis: {axes}.') + + # out axes + if out_axes is None: + out_axes = 0 + elif isinstance(out_axes, (int, dict)): + out_axes = (out_axes, 0, 0) + elif isinstance(out_axes, (tuple, list)): + out_axes = tuple(out_axes) + (0, 0) + assert isinstance(out_axes, (list, tuple)) + + # reduce_func + if reduce_func is None: + reduce_func = lambda x: x.mean(axis=0) + + # jit function + return _make_vmap(func=func, + nonbatched_vars=_dyn_vars, + batched_vars=_rand_vars, + in_axes=in_axes, + out_axes=out_axes, + axis_name=axis_name, + batch_idx=batch_idx) + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' + f'function, but we got {type(func)}.') + + +def _device_reshape(x): + """Reshape an input array in order to broadcast to multiple devices.""" + num_device = jax.local_device_count() + + if not hasattr(x, 'ndim'): + raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' + f'parallel, first convert it to a Array, for example np.float(0.5)') + if x.ndim == 0: + return np.broadcast_to(x, [num_device]) + if x.shape[0] % num_device != 0: + raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' + f'{num_device} devices, but does not go equally.') + return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) + + +def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, + out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, + axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): + @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, + static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, + backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + def pmapped_func(dyn_data, rand_data, *args, **kwargs): + dyn_vars.assign(dyn_data) + rand_vars.assign(rand_data) + out = func(*args, **kwargs) + dyn_changes = dyn_vars.dict() + rand_changes = rand_vars.dict() + return out, dyn_changes, rand_changes + + def call(*args): + un_replicated = [k for k, v in dyn_vars.items() + if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] + if len(un_replicated): + raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' + f'did you forget to call xx.replicate() on them?') + _args = [] + for i, x in enumerate(args): + if i + 2 in static_broadcasted_argnums: + _args.append(x) + else: + _args.append(jax.tree_map(_device_reshape, [x])[0]) + dyn_data = dyn_vars.dict() + rand_data = rand_vars.dict() + output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) + dyn_vars.assign(dyn_changes) + rand_vars.assign(rand_changes) + return jax.tree_map(reduce_func, output) + + return change_func_name(name=f_name, f=call) if f_name else call + + +def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), + devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, + reduce_func=None): + """Parallel compilation for class objects. + + Parallel compile a function or a module to run on multiple devices in parallel. + + Parameters + ---------- + func + axis_name + in_axes + out_axes + static_broadcasted_argnums + devices + backend + axis_size + donate_argnums + global_arg_shapes + + Returns + ------- + + + Examples + -------- + + + """ + + # if isinstance(func, DynamicalSystem): + # if len(func.steps): # DynamicalSystem has step functions + # + # # dynamical variables + # all_vars = (dyn_vars or func.vars().unique()) + # dyn_vars = ArrayCollector() + # rand_vars = ArrayCollector() + # for key, val in all_vars.items(): + # if isinstance(val, RandomState): + # rand_vars[key] = val + # else: + # dyn_vars[key] = val + # + # # reduce function + # if reduce_func is None: + # reduce_func = jnp.concatenate + # + # # static broadcast-ed arguments + # if static_broadcasted_argnums is None: + # static_broadcasted_argnums = () + # elif isinstance(static_broadcasted_argnums, int): + # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + # elif isinstance(static_broadcasted_argnums, (tuple, list)): + # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + # assert isinstance(static_broadcasted_argnums, (tuple, list)) + # + # # jit functions + # for key in func.steps.keys(): + # step = func.steps[key] + # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, + # rand_vars=rand_vars, + # func=step, + # axis_name=axis_name, + # in_axes=in_axes, + # out_axes=out_axes, + # static_broadcasted_argnums=static_broadcasted_argnums, + # devices=devices, + # backend=backend, + # axis_size=axis_size, + # donate_argnums=donate_argnums, + # global_arg_shapes=global_arg_shapes, + # reduce_func=reduce_func, + # f_name=key) + # return func + + if callable(func): + if dyn_vars is not None: + dyn_vars = dyn_vars + elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation + dyn_vars = func.vars().unique() + elif hasattr(func, '__self__'): + if isinstance(func.__self__, BrainPyObject): + dyn_vars = func.__self__.vars().unique() + + if dyn_vars is None: + return jax.pmap(func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes) + else: + # dynamical variables + dyn_vars = ArrayCollector() + rand_vars = ArrayCollector() + for key, val in dyn_vars.items(): + if isinstance(val, RandomState): + rand_vars[key] = val + else: + dyn_vars[key] = val + + # static broadcast-ed arguments + if static_broadcasted_argnums is None: + static_broadcasted_argnums = () + elif isinstance(static_broadcasted_argnums, int): + static_broadcasted_argnums = (static_broadcasted_argnums + 2,) + elif isinstance(static_broadcasted_argnums, (tuple, list)): + static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) + assert isinstance(static_broadcasted_argnums, (tuple, list)) + + # reduce function + if reduce_func is None: + reduce_func = jnp.concatenate + + # jit function + func.__call__ = _make_pmap(dyn_vars=dyn_vars, + rand_vars=rand_vars, + func=func, + axis_name=axis_name, + in_axes=in_axes, + out_axes=out_axes, + static_broadcasted_argnums=static_broadcasted_argnums, + devices=devices, + backend=backend, + axis_size=axis_size, + donate_argnums=donate_argnums, + global_arg_shapes=global_arg_shapes, + reduce_func=reduce_func) + return func + + else: + raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' + f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 632c6d79..7b519590 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,65 +132,19 @@ def evaluate_dyn_vars_with_cache( return stack -def _partial_fun2( - fun: Callable, - args: tuple, - kwargs: dict, - static_argnums: Sequence[int] = (), - static_argnames: Sequence[str] = () -): - num_args = len(args) - - # arguments - static_args = dict() - dyn_args = [] - dyn_arg_ids = dict() - static_argnums = list(static_argnums) - dyn_i = 0 - for i in range(num_args): - if i in static_argnums: - static_argnums.remove(i) - static_args[i] = args[i] - else: - dyn_args.append(args[i]) - dyn_arg_ids[i] = dyn_i - dyn_i += 1 - if len(static_argnums) > 0: - raise ValueError(f"Invalid static_argnums: {static_argnums}") - - # keyword arguments - static_kwargs, dyn_kwargs = {}, {} - for k, arg in kwargs.items(): - if k in static_argnames: - static_kwargs[k] = arg - else: - dyn_kwargs[k] = arg - del args, kwargs, static_argnums, static_argnames - - @wraps(fun) - def new_fun(*dynargs, **dynkwargs): - return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], - **static_kwargs, - **dynkwargs) - - return new_fun, dyn_args, dyn_kwargs - - def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), - with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: The positional arguments. - **kwargs: The keyword arguments. - with_stack: Whether evaluate the function within a local variable stack. + *args: + **kwargs: static_argnums: The static argument indices. static_argnames: The static argument names. @@ -199,30 +153,21 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) + f2, args, kwargs = _partial_fun(fun, args, kwargs, + static_argnums=static_argnums, + static_argnames=static_argnames) else: - f2 = fun + f2, args, kwargs = fun, args, kwargs # evaluate the function fun_in_eval_shape.append(fun) try: - if with_stack: + with jax.ensure_compile_time_eval(): with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) + returns = fun(*args, **kwargs) else: - returns = jax.eval_shape(f2, *args, **kwargs) - else: - stack = None - if len(fun_in_eval_shape) > 1: - returns = f2(*args, **kwargs) - else: - returns = jax.eval_shape(f2, *args, **kwargs) + returns = jax.eval_shape(fun, *args, **kwargs) finally: fun_in_eval_shape.pop() - del f2 - if with_stack: - return stack, returns - else: - return returns - + return stack, returns diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index b7babae8..5014da0b 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,3 +1,4 @@ +from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -189,14 +190,6 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id - @classmethod - def num_of_stack(self): - return len(var_stack_list) - - @classmethod - def is_first_stack(self): - return len(var_stack_list) == 0 - def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -217,6 +210,42 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] +transform_stack: List[Callable] = [] + + +@contextmanager +def new_transform(transform: Any): + transform_stack.append(transform) + try: + yield + finally: + transform_stack.pop() + + +def outermost_stack(): + if len(var_stack_list): + return var_stack_list[0] + else: + return None + + +def outermost_transform(): + if len(transform_stack): + return transform_stack[0] + else: + return None + + +def current_transform_number(): + return len(transform_stack) + + +def _stack_add_read(var: 'Variable'): + pass + + +def _stack_add_write(var: 'Variable'): + pass @register_pytree_node_class diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 01f77dbc..ed687eea 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -1,7 +1,8 @@ - -from .numba_approach import (CustomOpByNumba, - register_op_with_numba, - compile_cpu_signature_with_numba) -from .taichi_aot_based import clean_caches, check_kernels_count -from .base import XLACustomOp -from .utils import register_general_batching +from .numba_approach import (CustomOpByNumba, + register_op_with_numba, + compile_cpu_signature_with_numba) +from .base import XLACustomOp +from .utils import register_general_batching +from .taichi_aot_based import clean_caches, check_kernels_count +from .base import XLACustomOp +from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 1824ac91..ca070a19 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -4,8 +4,8 @@ import jax import numpy as np from jax.interpreters import xla, batching, ad, mlir -from numba.core.dispatcher import Dispatcher +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -20,6 +20,8 @@ from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp +numba = import_numba(error_if_not_found=False) + __all__ = [ 'XLACustomOp', ] @@ -104,24 +106,30 @@ def __init__( self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) # cpu function + cpu_checked = False if cpu_kernel is None: - pass - elif isinstance(cpu_kernel, Dispatcher): # numba - register_numba_cpu_translation_rule(self.primitive, cpu_kernel) - elif hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi + cpu_checked = True + if numba is not None: # numba + from numba.core.dispatcher import Dispatcher + if isinstance(cpu_kernel, Dispatcher): + register_numba_cpu_translation_rule(self.primitive, cpu_kernel) + cpu_checked = True + if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) - else: + cpu_checked = True + if not cpu_checked: raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' f'But we got {cpu_kernel}') # gpu function + gpu_checked = False if gpu_kernel is None: - pass - elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + gpu_checked = True + if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) - else: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. ' - f'But we got {gpu_kernel}') + gpu_checked = True + if not gpu_checked: + raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index cc2ce5b4..5bbd04e0 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -1,20 +1,22 @@ # -*- coding: utf-8 -*- -import warnings from functools import partial from typing import Callable from typing import Union, Sequence -import numba import jax from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map -from numba.core.dispatcher import Dispatcher +from brainpy._src.dependency_check import import_numba from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +from brainpy.errors import PackageMissingError from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba +numba = import_numba(error_if_not_found=False) + + __all__ = [ 'CustomOpByNumba', 'register_op_with_numba', @@ -137,6 +139,9 @@ def register_op_with_numba( f'For more information, please refer to the documentation: ' f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + if numba is None: + raise PackageMissingError.by_purpose('numba', 'custom op with numba') + if out_shapes is None: raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' @@ -146,6 +151,7 @@ def register_op_with_numba( prim.multiple_results = multiple_results # user defined function + from numba.core.dispatcher import Dispatcher if not isinstance(cpu_func, Dispatcher): cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) @@ -196,5 +202,3 @@ def abs_eval_rule(*input_shapes, **info): ad.primitive_transposes[prim] = transpose_translation return prim - - diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py index 13974b5b..4b06effd 100644 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py @@ -1,146 +1,152 @@ -# -*- coding: utf-8 -*- - -import ctypes - -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client -from numba import types, carray, cfunc - -__all__ = [ - 'compile_cpu_signature_with_numba' -] - -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - - -def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): - target_name, inputs, input_shapes, xla_output_shapes = \ - compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_shapes, - shape_with_layout=xla_output_shapes, - ) - - -def _cpu_signature( - func, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - multiple_results: bool, - debug: bool = False -): - code_scope = dict( - func_to_call=func, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) - - # inputs - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - - # outputs - if multiple_results: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - - # function body - code_string = ''' -def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - new_f = code_scope['xla_cpu_custom_call_target'] - if multiple_results: - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) - else: - xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) - target_name = xla_c_rule.native_name.encode("ascii") - capsule = ctypes.pythonapi.PyCapsule_New( - xla_c_rule.address, # A CFFI pointer to a function - b"xla._CUSTOM_CALL_TARGET", # A binary string - None # PyCapsule object run at destruction - ) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - return target_name - - -def compile_cpu_signature_with_numba( - c, - func, - abs_eval_fn, - multiple_results, - inputs: tuple, - description: dict = None, -): - input_layouts = [c.get_shape(arg) for arg in inputs] - info_inputs = [] - if description is None: description = dict() - for v in description.values(): - if isinstance(v, (int, float)): - input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) - elif isinstance(v, (tuple, list)): - v = jnp.asarray(v) - input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) - info_inputs.append(xla_client.ops.Constant(c, v)) - else: - raise TypeError - input_layouts = tuple(input_layouts) - input_dtypes = tuple(shape.element_type() for shape in input_layouts) - input_dimensions = tuple(shape.dimensions() for shape in input_layouts) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_layouts[:len(inputs)]), - **description) - if isinstance(output_abstract_arrays, ShapedArray): - output_abstract_arrays = (output_abstract_arrays,) - assert not multiple_results - else: - assert multiple_results - output_shapes = tuple(array.shape for array in output_abstract_arrays) - output_dtypes = tuple(array.dtype for array in output_abstract_arrays) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - target_name = _cpu_signature(func, - input_dtypes, - input_dimensions, - output_dtypes, - output_shapes, - multiple_results, - debug=False) - output_layouts = [xla_client.Shape.array_shape(*arg) - for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_layouts = (xla_client.Shape.tuple_shape(output_layouts) - if multiple_results else - output_layouts[0]) - return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts +# -*- coding: utf-8 -*- + +import ctypes + +from jax import dtypes, numpy as jnp +from jax.core import ShapedArray +from jax.lib import xla_client + +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +ctypes.pythonapi.PyCapsule_New.argtypes = [ + ctypes.c_void_p, # void* pointer + ctypes.c_char_p, # const char *name + ctypes.c_void_p, # PyCapsule_Destructor destructor +] +ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object + +__all__ = [ + '_cpu_translation', + 'compile_cpu_signature_with_numba', +] + +if numba is not None: + from numba import types, carray, cfunc + + +def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): + target_name, inputs, input_shapes, xla_output_shapes = \ + compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) + return xla_client.ops.CustomCallWithLayout( + c, + target_name, + operands=inputs, + operand_shapes_with_layout=input_shapes, + shape_with_layout=xla_output_shapes, + ) + + +def _cpu_signature( + func, + input_dtypes, + input_shapes, + output_dtypes, + output_shapes, + multiple_results: bool, + debug: bool = False +): + code_scope = dict( + func_to_call=func, + input_shapes=input_shapes, + input_dtypes=input_dtypes, + output_shapes=output_shapes, + output_dtypes=output_dtypes, + carray=carray, + ) + + # inputs + if len(input_shapes) > 1: + args_in = [ + f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' + for i in range(len(input_shapes)) + ] + args_in = '(\n ' + "\n ".join(args_in) + '\n )' + else: + args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' + + # outputs + if multiple_results: + args_out = [ + f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' + for i in range(len(output_shapes)) + ] + args_out = '(\n ' + "\n ".join(args_out) + '\n )' + else: + args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' + + # function body + code_string = ''' +def xla_cpu_custom_call_target(output_ptrs, input_ptrs): + args_out = {args_out} + args_in = {args_in} + func_to_call(args_out, args_in) + '''.format(args_in=args_in, + args_out=args_out) + if debug: print(code_string) + exec(compile(code_string.strip(), '', 'exec'), code_scope) + + new_f = code_scope['xla_cpu_custom_call_target'] + if multiple_results: + xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), + types.CPointer(types.voidptr)))(new_f) + else: + xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) + target_name = xla_c_rule.native_name.encode("ascii") + capsule = ctypes.pythonapi.PyCapsule_New( + xla_c_rule.address, # A CFFI pointer to a function + b"xla._CUSTOM_CALL_TARGET", # A binary string + None # PyCapsule object run at destruction + ) + xla_client.register_custom_call_target(target_name, capsule, "cpu") + return target_name + + +def compile_cpu_signature_with_numba( + c, + func, + abs_eval_fn, + multiple_results, + inputs: tuple, + description: dict = None, +): + input_layouts = [c.get_shape(arg) for arg in inputs] + info_inputs = [] + if description is None: description = dict() + for v in description.values(): + if isinstance(v, (int, float)): + input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) + info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) + elif isinstance(v, (tuple, list)): + v = jnp.asarray(v) + input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) + info_inputs.append(xla_client.ops.Constant(c, v)) + else: + raise TypeError + input_layouts = tuple(input_layouts) + input_dtypes = tuple(shape.element_type() for shape in input_layouts) + input_dimensions = tuple(shape.dimensions() for shape in input_layouts) + output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) + for shape in input_layouts[:len(inputs)]), + **description) + if isinstance(output_abstract_arrays, ShapedArray): + output_abstract_arrays = (output_abstract_arrays,) + assert not multiple_results + else: + assert multiple_results + output_shapes = tuple(array.shape for array in output_abstract_arrays) + output_dtypes = tuple(array.dtype for array in output_abstract_arrays) + output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) + target_name = _cpu_signature(func, + input_dtypes, + input_dimensions, + output_dtypes, + output_shapes, + multiple_results, + debug=False) + output_layouts = [xla_client.Shape.array_shape(*arg) + for arg in zip(output_dtypes, output_shapes, output_layouts)] + output_layouts = (xla_client.Shape.tuple_shape(output_layouts) + if multiple_results else + output_layouts[0]) + return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fb76aed2..f461f427 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -6,17 +6,20 @@ from jax.interpreters import xla, mlir from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -from numba import types, carray, cfunc +from brainpy._src.dependency_check import import_numba +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout +numba = import_numba(error_if_not_found=False) +if numba is not None: + from numba import types, carray, cfunc __all__ = [ 'register_numba_xla_cpu_translation_rule', 'register_numba_mlir_cpu_translation_rule', ] - # [void* pointer, # const char *name, # PyCapsule_Destructor destructor] @@ -104,6 +107,9 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, @@ -168,5 +174,8 @@ def numba_cpu_custom_call_target(output_ptrs, input_ptrs): def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): + if numba is None: + raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') + rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py index 24f010a1..2c9f0972 100644 --- a/brainpy/_src/math/op_register/tests/test_ad_support.py +++ b/brainpy/_src/math/op_register/tests/test_ad_support.py @@ -1,13 +1,18 @@ +import pytest from typing import Tuple import jax -import numba from jax import core from jax import numpy as jnp from jax.interpreters import ad import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_numba + +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index 968155ef..dc093f62 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,6 +1,11 @@ +import pytest import jax.core import brainpy.math as bm -import numba + +from brainpy._src.dependency_check import import_numba +numba = import_numba(error_if_not_found=False) +if numba is None: + pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 03023754..4db38fbc 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,9 +1,14 @@ +import pytest import jax import jax.numpy as jnp -import taichi as ti import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + bm.set_platform('cpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 1bebcdaf..51c964b2 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,54 +1,58 @@ -import brainpy.math as bm -import jax -import jax.numpy as jnp -import platform -import pytest -import taichi - -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) - -@taichi.func -def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32: - return weight[0] - - -@taichi.func -def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32): - out[index] += weight_val - -@taichi.kernel -def event_ell_cpu(indices: taichi.types.ndarray(ndim=2), - vector: taichi.types.ndarray(ndim=1), - weight: taichi.types.ndarray(ndim=1), - out: taichi.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - taichi.loop_config(serialize=True) - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - -prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) - -def test_taichi_clean_cache(): - s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) - vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - print(out) - bm.clear_buffer_memory() - - print('kernels: ', bm.check_kernels_count()) - - bm.clean_caches() - - print('kernels: ', bm.check_kernels_count()) - +import brainpy.math as bm +import jax +import jax.numpy as jnp +import platform +import pytest + +from brainpy._src.dependency_check import import_taichi +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + +if not platform.platform().startswith('Windows'): + pytest.skip(allow_module_level=True) + +@ti.func +def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: + return weight[0] + + +@ti.func +def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): + out[index] += weight_val + +@ti.kernel +def event_ell_cpu(indices: ti.types.ndarray(ndim=2), + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + weight_val = get_weight(weight) + num_rows, num_cols = indices.shape + ti.loop_config(serialize=True) + for i in range(num_rows): + if vector[i]: + for j in range(num_cols): + update_output(out, indices[i, j], weight_val) + +prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + +def test_taichi_clean_cache(): + s = 1000 + indices = bm.random.randint(0, s, (s, 1000)) + vector = bm.random.rand(s) < 0.1 + weight = bm.array([1.0]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + print(out) + bm.clear_buffer_memory() + + print('kernels: ', bm.check_kernels_count()) + + bm.clean_caches() + + print('kernels: ', bm.check_kernels_count()) + # test_taichi_clean_cache() \ No newline at end of file diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80..d5353324 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,7 @@ - -from ._coo_mv import * +# from ._coo_mv import * +# from ._bsr_mv import * from ._csr_mv import * from ._utils import * -from ._bsr_mv import * from ._bsr_mm import * from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py index 453ab387..19800749 100644 --- a/brainpy/_src/math/sparse/_bsr_mm.py +++ b/brainpy/_src/math/sparse/_bsr_mm.py @@ -1,22 +1,23 @@ # -*- coding: utf-8 -*- from functools import partial -from typing import Union, Tuple +from typing import Tuple import jax.lax -import numba import numpy as np from jax import numpy as jnp from jax.core import Primitive, ShapedArray from jax.interpreters import ad, xla from jax.lib import xla_client +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba from brainpy._src.math.interoperability import as_jax -from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) from brainpy.errors import GPUOperatorNotFound +numba = import_numba(error_if_not_found=False) + __all__ = [ 'bcsrmm', ] @@ -264,52 +265,53 @@ def bcsrmm( raise ValueError -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_k, block_size_n) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_n, block_size_k) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val +if numba is not None: + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_k, block_size_n) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val + + + @numba.njit(fastmath=True, parallel=True, nogil=True) + def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_n, block_size_k) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val def _bcsrmm_cutlass_abstract( diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index 37759757..42969f43 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -1,28 +1,22 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple import jax -import numba -import numpy as np -from jax import core, dtypes from jax import numpy as jnp -from jax.interpreters import ad, mlir, xla -from jax.lib import xla_client -from jaxlib import gpu_sparse +from jax.experimental.sparse import csr +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) +from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'csrmv', @@ -37,7 +31,6 @@ def csrmv( *, shape: Tuple[int, int], transpose: bool = False, - method: str = None, ): """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. @@ -70,495 +63,6 @@ def csrmv( - ``vector``: - ``adaptive``: - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - if method is None: - return csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose) - else: - return csrmv_brainpylib(data, indices, indptr, vector, shape=shape, transpose=transpose, method=method) - - -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, - method: str = 'cusparse', -): - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - method: str - The method used to compute Matrix-Vector Multiplication. The candidate methods are: - - - ``cusparse``: using cuSPARSE library. - - ``scalar``: - - ``vector``: - - ``adaptive``: - - Returns - ------- - y : ndarry - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if method == 'cusparse': - if jax.default_backend() == 'gpu': - if data.shape[0] == 1: - data = jnp.ones(indices.shape, dtype=data.dtype) * data - if indices.dtype in [jnp.uint32, jnp.uint64]: - indices = jnp.asarray(indices, dtype=dtypes.canonicalize_dtype(jnp.int64)) - if indptr.dtype in [jnp.uint32, jnp.uint64]: - indptr = jnp.asarray(indptr, dtype=dtypes.canonicalize_dtype(jnp.int64)) - return _csrmv_cusparse_p.bind(data, - indices, - indptr, - vector, - shape=shape, - transpose=transpose) - - elif method == 'adaptive': - return _csrmv_adaptive_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'scalar': - return _csrmv_scalar_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - elif method == 'vector': - return _csrmv_vector_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose) - - else: - raise ValueError(f'Only support methods: cusparse, scalar, vector, and adaptive. But we got {method}.') - - -def _csrmv_abstract(data, indices, indptr, vector, *, shape, transpose): - if data.dtype not in [jnp.float32, jnp.float64]: - raise TypeError(f'Only support float32 and float64. But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - out_shape = shape[1] if transpose else shape[0] - return core.ShapedArray((out_shape,), data.dtype) - - -@numba.njit(fastmath=True) -def _csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # (csr mat).T @ vec - - if values.shape[0] == 1: - values = values[0] - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += values * v - else: - for row_i in range(shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += v * values[j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, col_indices, row_ptr, vector, shape, _ = ins - # csr mat @ vec - if values.shape[0] == 1: - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values * vector[col_indices[j]] - res_val[row_i] = r - else: - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - res_val[row_i] = r - - -def _csrmv_cpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - inputs = (data, indices, indptr, vector) - description = dict(shape=shape, transpose=transpose) - if transpose: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_transpose_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - target_name, inputs, input_layouts, output_layouts = compile_cpu_signature_with_numba( - c, - _csr_matvec_numba_imp, - _csrmv_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_layouts, - shape_with_layout=output_layouts, - ) - - -def _csrmv_cusparse_gpu_lowering(ctx, data, indices, indptr, vector, *, shape, transpose): - data_aval, indices_aval, _, v_aval = ctx.avals_in - dtype = data_aval.dtype - if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: - raise TypeError(f"cusparse_csr_matvec cusparse/hipsparse lowering not available for dtype={dtype}. " - "Falling back to default implementation.") - return [gpu_sparse.cuda_csr_matvec(data, indices, indptr, vector, - shape=shape, - transpose=transpose, - data_dtype=dtype, - x_dtype=v_aval.dtype, - index_dtype=indices_aval.dtype)] - - -def _csrmv_jvp_mat(csr_prim, data_dot, data, indices, indptr, v, *, shape, transpose): - return csr_prim.bind(data_dot, indices, indptr, v, shape=shape, transpose=transpose) - - -def _csrmv_jvp_vec(prim, v_dot, data, indices, indptr, v, *, shape, transpose): - return prim.bind(data, indices, indptr, v_dot, shape=shape, transpose=transpose) - - -def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return data, indices, indptr, ad.Zero(vector) - else: - ct_vector = _csrmv_cusparse_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, ct_vector - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_cusparse_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_cusparse_p = core.Primitive('cusparse_csr_matvec') -_csrmv_cusparse_p.def_abstract_eval(_csrmv_abstract) -_csrmv_cusparse_p.def_impl(partial(xla.apply_primitive, _csrmv_cusparse_p)) -# xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation -ad.defjvp(_csrmv_cusparse_p, - partial(_csrmv_jvp_mat, _csrmv_cusparse_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_cusparse_p)) -ad.primitive_transposes[_csrmv_cusparse_p] = _csrmv_cusparse_transpose -register_general_batching(_csrmv_cusparse_p) -mlir.register_lowering(_csrmv_cusparse_p, _csrmv_cusparse_gpu_lowering, platform='cuda') - - -def _csr_matvec_scalar_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_scalar_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_scalar' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_scalar_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_scalar_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_scalar_p = core.Primitive('csr_matvec_scalar') -_csrmv_scalar_p.def_abstract_eval(_csrmv_abstract) -_csrmv_scalar_p.def_impl(partial(xla.apply_primitive, _csrmv_scalar_p)) -# xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation -ad.defjvp(_csrmv_scalar_p, - partial(_csrmv_jvp_mat, _csrmv_scalar_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_scalar_p), ) -ad.primitive_transposes[_csrmv_scalar_p] = _csrmv_scalar_transpose -register_general_batching(_csrmv_scalar_p) - - -def _csr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_vector_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_vector_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_vector_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_vector_p = core.Primitive('csr_matvec_vector') -_csrmv_vector_p.def_abstract_eval(_csrmv_abstract) -_csrmv_vector_p.def_impl(partial(xla.apply_primitive, _csrmv_vector_p)) -# xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation -ad.defjvp(_csrmv_vector_p, - partial(_csrmv_jvp_mat, _csrmv_vector_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_vector_p), ) -ad.primitive_transposes[_csrmv_vector_p] = _csrmv_vector_transpose -register_general_batching(_csrmv_vector_p) - - -def _csr_matvec_adaptive_gpu_translation(c, data, indices, indptr, row_blocks, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_csrmv_adaptive_p.name) - if transpose: - raise NotImplementedError - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - ftype = b'_float' - elif data_shape.element_type() == np.float64: - ftype = b'_double' - else: - raise ValueError - indices_shape = c.get_shape(indices) - if indices_shape.element_type() == np.int32: - itype = b'_int' - elif indices_shape.element_type() == np.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'homo' if data_shape.dimensions() == (1,) else b'heter' - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - return xla_client.ops.CustomCallWithLayout( - c, - b'csrmv_' + data_name + b'_vector' + ftype + itype, - operands=(data, indices, indptr, row_blocks, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(row_blocks), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0],), (0,)), - opaque=opaque, - ) - - -def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(vector): - ct_vector = _csrmv_adaptive_p.bind(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = _csrmv_adaptive_p.bind(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -_csrmv_adaptive_p = core.Primitive('csr_matvec_adaptive') -_csrmv_adaptive_p.def_abstract_eval(_csrmv_abstract) -_csrmv_adaptive_p.def_impl(partial(xla.apply_primitive, _csrmv_adaptive_p)) -# xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation -# xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation -ad.defjvp(_csrmv_adaptive_p, - partial(_csrmv_jvp_mat, _csrmv_adaptive_p), - None, - None, - partial(_csrmv_jvp_vec, _csrmv_adaptive_p), ) -ad.primitive_transposes[_csrmv_adaptive_p] = _csrmv_adaptive_transpose -register_general_batching(_csrmv_adaptive_p) - - -### TAICHI ### - -def csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -) -> jax.Array: - """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - vector: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple of int - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - Returns ------- y : ndarry @@ -593,171 +97,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - -@ti.kernel -def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - -@ti.kernel -def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - - -@ti.kernel -def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - -@ti.kernel -def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - -@ti.kernel -def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - -def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - def raw_csrmv_taichi( data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], @@ -767,17 +106,22 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False, ): + if ti is None: + raise PackageMissingError.by_purpose('taichi', purpose='customized operators') out_shape = shape[1] if transpose else shape[0] - if transpose: - if data.shape[0] == 1: - prim = _csr_matvec_transpose_homo_p + if data.shape[0] != 1: + if bm.get_platform() == 'gpu': + return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] else: - prim = _csr_matvec_transpose_heter_p + if transpose: + prim = _csr_matvec_transpose_heter_p + else: + prim = _csr_matvec_heter_p else: - if data.shape[0] == 1: - prim = _csr_matvec_homo_p + if transpose: + prim = _csr_matvec_transpose_homo_p else: - prim = _csr_matvec_heter_p + prim = _csr_matvec_homo_p return prim(data, indices, @@ -788,25 +132,193 @@ def raw_csrmv_taichi( shape=shape) -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim +if ti is not None: + + # ------------- + # CPU operators + # ------------- + @ti.kernel + def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += vector[row_i] * values[j] + + + @ti.kernel + def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += vector[col_indices[j]] + out[row_i] = r * value + + + @ti.kernel + def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + @ti.kernel + def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += value * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += vector[col_indices[j]] + j += 32 + out[row_i] += value * r + + + @ti.kernel + def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + out[col_indices[j]] += values[j] * vector[row_i] + j += 32 + + + @ti.kernel + def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((row_ptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = row_ptr[row_i] + index + end_index = row_ptr[row_i + 1] + while j < end_index: + r += values[j] * vector[col_indices[j]] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): + return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) + + + def _sparse_csr_matvec_transpose( + ct, data, indices, indptr, vector, *, outs, transpose, shape, + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(vector): + ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] + + return ct_data, indices, indptr, vector + + + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) + prim.def_transpose_rule(_sparse_csr_matvec_transpose) + return prim + + # transpose homo + _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) -# transpose homo -_csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) + # no transpose homo + _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, + gpu_kernel=_sparse_csr_matvec_homo_gpu) -# no transpose homo -_csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) + # transpose heter + _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, + gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) -# transpose heter -_csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) + # no transpose heter + _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, + gpu_kernel=_sparse_csr_matvec_heter_gpu) -# no transpose heter -_csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + # heter cusparse + _csr_matvec_cusparse_p = csr.csr_matvec_p + register_general_batching(_csr_matvec_cusparse_p) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/_utils.py index a1dc9190..f5b74e5e 100644 --- a/brainpy/_src/math/sparse/_utils.py +++ b/brainpy/_src/math/sparse/_utils.py @@ -3,9 +3,8 @@ import warnings from typing import Tuple -import jax import numpy as np -from jax import core, numpy as jnp, dtypes +from jax import core, numpy as jnp from jax.interpreters import mlir, ad from jaxlib import gpu_sparse diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 2c75f090..ec448e65 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -5,10 +5,14 @@ import jax from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -# bm.set_platform('gpu') +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) seed = 1234 diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py deleted file mode 100644 index b7321749..00000000 --- a/brainpy/_src/math/sparse/tests/test_csrmv_old.py +++ /dev/null @@ -1,352 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial - -import jax -import pytest -from absl.testing import parameterized -import platform -import brainpy as bp -import brainpy.math as bm - -pytest.skip('Old implementation.', allow_module_level=True) - -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - -cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') -scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') -vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') - - -class Test_cusparse_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_cusparse_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.ones(indices.shape).value * homo_data - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r3 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - v=[-1., 0., 1.] - ) - def test_homo_vmap(self, transpose, shape, v): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = bm.ones((10, indices.shape[0])).value * v - homo_data = bm.ones(10).value * v - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(homo_data) - r2 = jax.vmap(f1)(heter_data) - self.assertTrue(bm.allclose(r1, r2)) - - r3 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r3)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum() - if transpose else - ((dense * a) @ vector).sum()), - argnums=0) - - r1 = csr_f1(homo_data) - r2 = dense_f1(homo_data) - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(homo_data, indices, indptr, v, - shape=shape, transpose=transpose).sum()) - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) - - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - csr_f3 = jax.grad(lambda a, v: cusparse_csr_matvec(a, indices, indptr, v, - shape=shape, transpose=transpose).sum(), - argnums=(0, 1)) - dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum() - if transpose else - ((dense * a) @ v).sum()), - argnums=(0, 1)) - - r5 = csr_f3(homo_data, vector) - r6 = dense_f3(homo_data, vector) - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - ) - def test_heter(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - r1 = cusparse_csr_matvec(heter_data, indices, indptr, vector, - shape=shape, transpose=transpose) - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r2 = (vector @ dense) if transpose else (dense @ vector) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_vmap(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - heter_data = rng.random((10, indices.shape[0])) - heter_data = bm.as_jax(heter_data) - dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=shape))(heter_data) - - f1 = partial(cusparse_csr_matvec, indices=indices, indptr=indptr, vector=vector, - shape=shape, transpose=transpose) - f2 = lambda a: (a.T @ vector) if transpose else (a @ vector) - - r1 = jax.vmap(f1)(heter_data) - r2 = jax.vmap(f2)(dense_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = bm.as_jax(heter_data) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - - r1 = csr_f1(heter_data) - r2 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - r2 = r2[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - - csr_f2 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - r3 = csr_f2(vector) - r4 = dense_f2(vector) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - -class Test_csrmv(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmv, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - homo_data=[-1., 0., 0.1, 1.], - shape=[(100, 200), (10, 1000), (2, 2000)], - ) - def test_homo(self, shape, homo_data): - conn = bp.conn.FixedProb(0.1) - - # matrix - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # vector - rng = bm.random.RandomState(123) - vector = rng.random(shape[1]) - vector = bm.as_jax(vector) - - # csrmv - r1 = scalar_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - - heter_data = bm.ones(indices.shape).to_jax() * homo_data - r4 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r5 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r6 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - self.assertTrue(bm.allclose(r1, r4)) - self.assertTrue(bm.allclose(r1, r5)) - self.assertTrue(bm.allclose(r1, r6)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - rdense = dense @ vector - self.assertTrue(bm.allclose(r1, rdense)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(100, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = bm.as_jax(rng.random(indices.shape)) - vector = bm.as_jax(rng.random(shape[1])) - - r1 = scalar_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r2 = cusparse_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - r3 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - r4 = dense @ vector - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] - ) - def test_heter_grad(self, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - - indices, indptr = conn(*shape).require('pre2post') - heter_data = rng.random(indices.shape) - dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) - vector = rng.random(shape[1]) - - csr_f1 = jax.grad(lambda a: cusparse_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f2 = jax.grad(lambda a: scalar_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - csr_f3 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector, shape=shape).sum()) - dense_f1 = jax.grad(lambda a: (a @ vector).sum()) - - r1 = csr_f1(heter_data) - r2 = csr_f2(heter_data) - r3 = csr_f3(heter_data) - - d1 = dense_f1(dense_data) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - d1 = d1[rows, cols] - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, r3)) - self.assertTrue(bm.allclose(r1, d1)) - - # csr_f4 = jax.grad(lambda v: cusparse_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f5 = jax.grad(lambda v: scalar_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # csr_f6 = jax.grad(lambda v: vector_csr_matvec(heter_data, indices, indptr, v, shape=shape).sum()) - # dense_f2 = jax.grad(lambda v: (dense_data @ v).sum()) - # r4 = csr_f4(vector) - # r5 = csr_f5(vector) - # r6 = csr_f6(vector) - # d2 = dense_f2(vector) - # self.assertTrue(bm.allclose(r4, r5)) - # self.assertTrue(bm.allclose(r4, r6)) - # self.assertTrue(bm.allclose(r4, d2)) - - bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py index 6823ebab..db6e7deb 100644 --- a/brainpy/_src/math/tests/test_tifunc.py +++ b/brainpy/_src/math/tests/test_tifunc.py @@ -1,122 +1,124 @@ -# -*- coding: utf-8 -*- - -import jax -import jax.numpy as jnp -import pytest - -pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") -import brainpy.math as bm -import taichi as ti -import matplotlib.pyplot as plt -import os - - -bm.set_platform('cpu') - - -def test_taichi_random(): - @ti.kernel - def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), - out: ti.types.ndarray(ndim=1, dtype=ti.f32)): - key = bm.tifunc.lfsr88_key(seed[0]) - for i in range(out.shape[0]): - key, result = bm.tifunc.lfsr88_rand(key) - out[i] = result - - @ti.kernel - def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range(out.shape[0]): - out[i] = bm.tifunc.taichi_lcg_rand(seed) - - @ti.kernel - def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) - - @ti.kernel - def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), - low_high: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - low = low_high[0] - high = low_high[1] - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) - - @ti.kernel - def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), - mu_sigma: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - key = bm.tifunc.lfsr88_key(seed[0]) - mu = mu_sigma[0] - sigma = mu_sigma[1] - - for i in range(out.shape[0]): - key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) - - n = 100000 - seed = jnp.array([1234, ], dtype=jnp.uint32) - low_high = jnp.array([0, 10]) - mu_sigma = jnp.array([0, 1]) - - prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, - gpu_kernel=test_taichi_lfsr88) - - - prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, - gpu_kernel=test_taichi_lcg_rand) - prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, - gpu_kernel=test_taichi_uniform_int_distribution) - prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, - gpu_kernel=test_taichi_uniform_real_distribution) - prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, - gpu_kernel=test_taichi_normal_distribution) - - file_path = os.path.dirname(os.path.abspath(__file__)) - - out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LFSR88 random number generator") - plt.savefig(file_path + "/lfsr88.png") - plt.close() - - out = prim_lcg_rand(seed, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("LCG random number generator") - plt.savefig(file_path + "/lcg_rand.png") - plt.close() - - out = prim_uniform_int_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) - # show the distribution of out - plt.hist(out, bins=10) - plt.title("Uniform int distribution (0, 10)") - plt.savefig(file_path + "/uniform_int_distribution.png") - plt.close() - - out = prim_uniform_real_distribution(seed, low_high, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.hist(out, bins=100) - plt.title("Uniform real distribution (0, 10)") - plt.savefig(file_path + "/uniform_real_distribution.png") - plt.close() - - out = prim_normal_distribution(seed, mu_sigma, - outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) - # show the distribution of out - plt.title("Normal distribution mu=0, sigma=1") - plt.hist(out, bins=100) - plt.savefig(file_path + "/normal_distribution.png") - - -# TODO; test default types +# -*- coding: utf-8 -*- + +import jax +import jax.numpy as jnp +import pytest + +pytestmark = pytest.mark.skip(reason="Skipped due to MacOS limitation, manual execution required for testing.") +import brainpy.math as bm +import matplotlib.pyplot as plt +import os + +from brainpy._src.dependency_check import import_taichi + +ti = import_taichi(error_if_not_found=False) +if ti is None: + pytest.skip('no taichi', allow_module_level=True) + +bm.set_platform('cpu') + + +def test_taichi_random(): + @ti.kernel + def test_taichi_lfsr88(seed: ti.types.ndarray(ndim=1, dtype=ti.u32), + out: ti.types.ndarray(ndim=1, dtype=ti.f32)): + key = bm.tifunc.lfsr88_key(seed[0]) + for i in range(out.shape[0]): + key, result = bm.tifunc.lfsr88_rand(key) + out[i] = result + + @ti.kernel + def test_taichi_lcg_rand(seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range(out.shape[0]): + out[i] = bm.tifunc.taichi_lcg_rand(seed) + + @ti.kernel + def test_taichi_uniform_int_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_randint(key, low, high) + + @ti.kernel + def test_taichi_uniform_real_distribution(seed: ti.types.ndarray(ndim=1), + low_high: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + low = low_high[0] + high = low_high[1] + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_uniform(key, low, high) + + @ti.kernel + def test_taichi_normal_distribution(seed: ti.types.ndarray(ndim=1), + mu_sigma: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + key = bm.tifunc.lfsr88_key(seed[0]) + mu = mu_sigma[0] + sigma = mu_sigma[1] + + for i in range(out.shape[0]): + key, out[i] = bm.tifunc.lfsr88_normal(key, mu, sigma) + + n = 100000 + seed = jnp.array([1234, ], dtype=jnp.uint32) + low_high = jnp.array([0, 10]) + mu_sigma = jnp.array([0, 1]) + + prim_lfsr88 = bm.XLACustomOp(cpu_kernel=test_taichi_lfsr88, + gpu_kernel=test_taichi_lfsr88) + + prim_lcg_rand = bm.XLACustomOp(cpu_kernel=test_taichi_lcg_rand, + gpu_kernel=test_taichi_lcg_rand) + prim_uniform_int_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_int_distribution, + gpu_kernel=test_taichi_uniform_int_distribution) + prim_uniform_real_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_uniform_real_distribution, + gpu_kernel=test_taichi_uniform_real_distribution) + prim_normal_distribution = bm.XLACustomOp(cpu_kernel=test_taichi_normal_distribution, + gpu_kernel=test_taichi_normal_distribution) + + file_path = os.path.dirname(os.path.abspath(__file__)) + + out = prim_lfsr88(seed, outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LFSR88 random number generator") + plt.savefig(file_path + "/lfsr88.png") + plt.close() + + out = prim_lcg_rand(seed, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("LCG random number generator") + plt.savefig(file_path + "/lcg_rand.png") + plt.close() + + out = prim_uniform_int_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.int32)]) + # show the distribution of out + plt.hist(out, bins=10) + plt.title("Uniform int distribution (0, 10)") + plt.savefig(file_path + "/uniform_int_distribution.png") + plt.close() + + out = prim_uniform_real_distribution(seed, low_high, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.hist(out, bins=100) + plt.title("Uniform real distribution (0, 10)") + plt.savefig(file_path + "/uniform_real_distribution.png") + plt.close() + + out = prim_normal_distribution(seed, mu_sigma, + outs=[jax.ShapeDtypeStruct((n,), jnp.float32)]) + # show the distribution of out + plt.title("Normal distribution mu=0, sigma=1") + plt.hist(out, bins=100) + plt.savefig(file_path + "/normal_distribution.png") + +# TODO; test default types diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index a9ee39f4..9cfd39e1 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -1,7 +1,7 @@ -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, raise_taichi_not_found from . import defaults -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ # taichi function for other utilities @@ -16,349 +16,330 @@ 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', ] +if ti is not None: -@ti.func -def _lcg_rand(state: ti.types.ndarray(ndim=1)): - # LCG constants - state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) - return state[0] + ############################################# + # Random Number Generator: LFSR88 algorithm # + ############################################# + @ti.func + def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). -@ti.func -def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): - """ - Generate a random number using the Taichi LCG algorithm. + This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. - Parameters: - seed (ti.types.ndarray): The seed value for the random number generator. + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - Returns: - float: A random number between 0 and 1. - """ + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3 MUST be larger than + 1, 7, and 15 respectively. + */ - return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR88 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) -############################################# -# Random Number Generator: LFSR88 algorithm # -############################################# + @ti.func + def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). -@ti.func -def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. - This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. + Returns: + ti.math.uvec4: The next random key. + """ + b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) + s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b + b = ((key[1] << 2) ^ key[1]) >> 25 + s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b + b = ((key[2] << 3) ^ key[2]) >> 11 + s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b + return ti.math.uvec4(s1, s2, s3, b) - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3 MUST be larger than - 1, 7, and 15 respectively. - */ + @ti.func + def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - Args: - seed: int. The seed value for the random number generator. + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ - Returns: - ti.math.uvec4: The random key for the LFSR88 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) + key, r = lfsr88_randn(key, epsilon) + return key, mu + sigma * r -@ti.func -def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). + @ti.func + def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with the standard normal distribution using the LFSR88 algorithm. - Args: - key: The state value for the random number generator. + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - Returns: - ti.math.uvec4: The next random key. - """ - b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) - s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b - b = ((key[1] << 2) ^ key[1]) >> 25 - s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b - b = ((key[2] << 3) ^ key[2]) >> 11 - s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b - return ti.math.uvec4(s1, s2, s3, b) + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + """ -@ti.func -def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. + key, u1 = lfsr88_rand(key) + key, u2 = lfsr88_rand(key) - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - key, r = lfsr88_randn(key, epsilon) - return key, mu + sigma * r + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) -@ti.func -def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with the standard normal distribution using the LFSR88 algorithm. + return key, z2 - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + @ti.func + def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. - """ + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) - key, u1 = lfsr88_rand(key) - key, u2 = lfsr88_rand(key) - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + @ti.func + def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): + key = lfsr88_next_key(key) + return key, dtype(key[0] ^ key[1] ^ key[2]) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + @ti.func + def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. - return key, z2 + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -@ti.func -def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. + @ti.func + def lfsr88_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) + Args: + key: The state value used for random number generation. + """ + key = lfsr88_next_key(key) + return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): - key = lfsr88_next_key(key) - return key, dtype(key[0] ^ key[1] ^ key[2]) + ############################################## + # Random Number Generator: LFSR113 algorithm # + ############################################## + @ti.func + def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). -@ti.func -def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. + This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3, s4 MUST be larger than + 1, 7, 15, and 127 respectively. + */ -@ti.func -def lfsr88_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. + Args: + seed: int. The seed value for the random number generator. - Args: - key: The state value used for random number generation. - """ - key = lfsr88_next_key(key) - return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Returns: + ti.math.uvec4: The random key for the LFSR113 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) -############################################## -# Random Number Generator: LFSR113 algorithm # -############################################## + @ti.func + def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. -@ti.func -def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Returns: + ti.math.uvec4: The next random key. + """ + z1 = key[0] + z2 = key[1] + z3 = key[2] + z4 = key[3] + b = ((z1 << 6) ^ z1) >> 13 + z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) + b = ((z2 << 2) ^ z2) >> 27 + z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) + b = ((z3 << 13) ^ z3) >> 21 + z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) + b = ((z4 << 3) ^ z4) >> 12 + z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) + return ti.math.uvec4(z1, z2, z3, z4) - This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + @ti.func + def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3, s4 MUST be larger than - 1, 7, 15, and 127 respectively. - */ + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ - Args: - seed: int. The seed value for the random number generator. + key, r = lfsr113_randn(key, epsilon) + return key, ti.cast(mu + sigma * r, defaults.ti_float) - Returns: - ti.math.uvec4: The random key for the LFSR113 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) + @ti.func + def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with standard normal distribution using the LFSR113 algorithm. -@ti.func -def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - Args: - key: The state value for the random number generator. + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - Returns: - ti.math.uvec4: The next random key. - """ - z1 = key[0] - z2 = key[1] - z3 = key[2] - z4 = key[3] - b = ((z1 << 6) ^ z1) >> 13 - z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) - b = ((z2 << 2) ^ z2) >> 27 - z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) - b = ((z3 << 13) ^ z3) >> 21 - z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) - b = ((z4 << 3) ^ z4) >> 12 - z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) - return ti.math.uvec4(z1, z2, z3, z4) + """ + key, u1 = lfsr113_rand(key) + key, u2 = lfsr113_rand(key) -@ti.func -def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - key, r = lfsr113_randn(key, epsilon) - return key, ti.cast(mu + sigma * r, defaults.ti_float) + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with standard normal distribution using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr113_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) - """ - key, u1 = lfsr113_rand(key) - key, u2 = lfsr113_rand(key) + @ti.func + def lfsr113_randint(key: ti.types.vector(4, ti.u32)): + key = lfsr113_next_key(key) + return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + @ti.func + def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) + + + @ti.func + def lfsr113_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. - return key, z2 + Args: + key: The state value used for random number generation. + """ + key = lfsr113_next_key(key) + return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. + ########################### + # Reductions: warp reduce # + ########################### - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr113_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def warp_reduce_sum_all(val): + """ + Warp reduce sum. + Args: + val (float): The value to be reduced. -@ti.func -def lfsr113_randint(key: ti.types.vector(4, ti.u32)): - key = lfsr113_next_key(key) - return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) + Returns: + float: The reduced value. + """ + for i in ti.static(range(1, 32)): + val += ti.static(ti.simt.warp.shfl_xor(val, i)) + return val -@ti.func -def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. + @ti.func + def warp_reduce_sum(val): + """ + Warp reduce sum. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - -@ti.func -def lfsr113_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + Args: + val (float): The value to be reduced. - Args: - key: The state value used for random number generation. - """ - key = lfsr113_next_key(key) - return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Returns: + float: The reduced value. + """ + for offset in ti.static((16, 8, 4, 2, 1)): + val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) + return val -########################### -# Reductions: warp reduce # -########################### - - -@ti.func -def warp_reduce_sum_all(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for i in ti.static(range(1, 32)): - val += ti.static(ti.simt.warp.shfl_xor(val, i)) - return val - - -@ti.func -def warp_reduce_sum(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for offset in ti.static((16, 8, 4, 2, 1)): - val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) - return val +else: + for func in __all__: + globals()[func] = raise_taichi_not_found \ No newline at end of file diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index dd6865e6..6f2411ee 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -1,134 +1,133 @@ -# -*- coding: utf-8 -*- - - -import unittest -import brainpy as bp -import brainpy.math as bm - - -class TestDSRunner(unittest.TestCase): - def test1(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 - - ds = ExampleDS() - runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_t_and_dt(self): - class ExampleDS(bp.DynamicalSystem): - def __init__(self): - super(ExampleDS, self).__init__() - self.i = bm.Variable(bm.zeros(1)) - - def update(self): - self.i += 1 * bp.share['dt'] - - runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) - runner.run(100.) - - def test_DSView(self): - class EINet(bp.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet, self).__init__() - - # network size - num_exc = int(800 * scale) - num_inh = int(200 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., method=method) - self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), - output=bp.synouts.COBA(E=-80.), g_max=wi, - tau=10., method=method) - - bm.random.seed() - - net = EINet(scale=1., method='exp_auto') - # with JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) - - # without JIT - runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) - - - -class TestMemoryEfficient(unittest.TestCase): - pass - - - - - - -# class TestMonitor(TestCase): -# def test_1d_array(self): -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones(1) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 -# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) -# -# def test_2d_array(): -# set(dt=0.1) -# try1 = TryGroup(monitors=['a']) -# try1.a = np.ones((2, 2)) -# try1.run(100.) -# -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# def test_monitor_with_every(): -# set(dt=0.1) -# -# # try1: 2d array -# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try1.run(100.) -# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 4, axis=1) -# assert np.allclose(series, try1.mon.a) -# -# # try2: 1d array -# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try2.a = np.array([1., 1.]) -# try2.run(100.) -# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# series = np.repeat(series, 2, axis=1) -# assert np.allclose(series, try2.mon.a) -# -# # try2: scalar -# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) -# try3.a = 1. -# try3.run(100.) -# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 -# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) -# assert np.allclose(series, try3.mon.a) +# -*- coding: utf-8 -*- + +import pytest +import unittest +import brainpy as bp +import brainpy.math as bm + +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) + + +class TestDSRunner(unittest.TestCase): + def test1(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 + + ds = ExampleDS() + runner = bp.DSRunner(ds, dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_t_and_dt(self): + class ExampleDS(bp.DynamicalSystem): + def __init__(self): + super(ExampleDS, self).__init__() + self.i = bm.Variable(bm.zeros(1)) + + def update(self): + self.i += 1 * bp.share['dt'] + + runner = bp.DSRunner(ExampleDS(), dt=1., monitors=['i'], progress_bar=False) + runner.run(100.) + + def test_DSView(self): + class EINet(bp.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet, self).__init__() + + # network size + num_exc = int(800 * scale) + num_inh = int(200 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) + self.E = bp.neurons.LIF(num_exc, **pars, method=method) + self.I = bp.neurons.LIF(num_inh, **pars, method=method) + self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. + self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.synapses.Exponential(self.E, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.E2I = bp.synapses.Exponential(self.E, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., method=method) + self.I2E = bp.synapses.Exponential(self.I, self.E[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + self.I2I = bp.synapses.Exponential(self.I, self.I[:100], bp.conn.FixedProb(0.02), + output=bp.synouts.COBA(E=-80.), g_max=wi, + tau=10., method=method) + + bm.random.seed() + + net = EINet(scale=1., method='exp_auto') + # with JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)]).run(1.) + + # without JIT + runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}, + inputs=[(net.E.input, 20.), (net.I.input, 20.)], jit=False).run(0.2) + + +class TestMemoryEfficient(unittest.TestCase): + pass + +# class TestMonitor(TestCase): +# def test_1d_array(self): +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones(1) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 1 +# assert np.allclose(np.arange(2, 1002).reshape((-1, 1)), try1.mon.a) +# +# def test_2d_array(): +# set(dt=0.1) +# try1 = TryGroup(monitors=['a']) +# try1.a = np.ones((2, 2)) +# try1.run(100.) +# +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# def test_monitor_with_every(): +# set(dt=0.1) +# +# # try1: 2d array +# try1 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try1.run(100.) +# assert np.ndim(try1.mon.a) == 2 and np.shape(try1.mon.a)[1] == 4 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 4, axis=1) +# assert np.allclose(series, try1.mon.a) +# +# # try2: 1d array +# try2 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try2.a = np.array([1., 1.]) +# try2.run(100.) +# assert np.ndim(try2.mon.a) == 2 and np.shape(try2.mon.a)[1] == 2 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# series = np.repeat(series, 2, axis=1) +# assert np.allclose(series, try2.mon.a) +# +# # try2: scalar +# try3 = TryGroup(monitors=Monitor(variables=['a'], every=[1.])) +# try3.a = 1. +# try3.run(100.) +# assert np.ndim(try3.mon.a) == 2 and np.shape(try3.mon.a)[1] == 1 +# series = np.arange(2, 1002, 1. / 0.1).reshape((-1, 1)) +# assert np.allclose(series, try3.mon.a) diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py deleted file mode 100644 index cbc710db..00000000 --- a/brainpy/_src/tools/functions.py +++ /dev/null @@ -1,192 +0,0 @@ -import inspect -from functools import partial -from operator import attrgetter -from types import MethodType - -__all__ = [ - 'compose', 'pipe' -] - - -def identity(x): - """ Identity function. Return x - - >>> identity(3) - 3 - """ - return x - - -def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): - """ Like @property, but returns ``classval`` when used as a class attribute - - >>> class MyClass(object): - ... '''The class docstring''' - ... @instanceproperty(classval=__doc__) - ... def __doc__(self): - ... return 'An object docstring' - ... @instanceproperty - ... def val(self): - ... return 42 - ... - >>> MyClass.__doc__ - 'The class docstring' - >>> MyClass.val is None - True - >>> obj = MyClass() - >>> obj.__doc__ - 'An object docstring' - >>> obj.val - 42 - """ - if fget is None: - return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, - classval=classval) - return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, - classval=classval) - - -class InstanceProperty(property): - """ Like @property, but returns ``classval`` when used as a class attribute - - Should not be used directly. Use ``instanceproperty`` instead. - """ - - def __init__(self, fget=None, fset=None, fdel=None, doc=None, - classval=None): - self.classval = classval - property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) - - def __get__(self, obj, type=None): - if obj is None: - return self.classval - return property.__get__(self, obj, type) - - def __reduce__(self): - state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) - return InstanceProperty, state - - -class Compose(object): - """ A composition of functions - - See Also: - compose - """ - __slots__ = 'first', 'funcs' - - def __init__(self, funcs): - funcs = tuple(reversed(funcs)) - self.first = funcs[0] - self.funcs = funcs[1:] - - def __call__(self, *args, **kwargs): - ret = self.first(*args, **kwargs) - for f in self.funcs: - ret = f(ret) - return ret - - def __getstate__(self): - return self.first, self.funcs - - def __setstate__(self, state): - self.first, self.funcs = state - - @instanceproperty(classval=__doc__) - def __doc__(self): - def composed_doc(*fs): - """Generate a docstring for the composition of fs. - """ - if not fs: - # Argument name for the docstring. - return '*args, **kwargs' - - return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) - - try: - return ( - 'lambda *args, **kwargs: ' + - composed_doc(*reversed((self.first,) + self.funcs)) - ) - except AttributeError: - # One of our callables does not have a `__name__`, whatever. - return 'A composition of functions' - - @property - def __name__(self): - try: - return '_of_'.join( - (f.__name__ for f in reversed((self.first,) + self.funcs)) - ) - except AttributeError: - return type(self).__name__ - - def __repr__(self): - return '{.__class__.__name__}{!r}'.format( - self, tuple(reversed((self.first,) + self.funcs))) - - def __eq__(self, other): - if isinstance(other, Compose): - return other.first == self.first and other.funcs == self.funcs - return NotImplemented - - def __ne__(self, other): - equality = self.__eq__(other) - return NotImplemented if equality is NotImplemented else not equality - - def __hash__(self): - return hash(self.first) ^ hash(self.funcs) - - # Mimic the descriptor behavior of python functions. - # i.e. let Compose be called as a method when bound to a class. - # adapted from - # docs.python.org/3/howto/descriptor.html#functions-and-methods - def __get__(self, obj, objtype=None): - return self if obj is None else MethodType(self, obj) - - # introspection with Signature is only possible from py3.3+ - @instanceproperty - def __signature__(self): - base = inspect.signature(self.first) - last = inspect.signature(self.funcs[-1]) - return base.replace(return_annotation=last.return_annotation) - - __wrapped__ = instanceproperty(attrgetter('first')) - - -def compose(*funcs): - """ Compose functions to operate in series. - - Returns a function that applies other functions in sequence. - - Functions are applied from right to left so that - ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. - - If no arguments are provided, the identity function (f(x) = x) is returned. - - >>> inc = lambda i: i + 1 - >>> compose(str, inc)(3) - '4' - """ - if not funcs: - return identity - if len(funcs) == 1: - return funcs[0] - else: - return Compose(funcs) - - -def pipe(*funcs): - """ Pipe a value through a sequence of functions - - I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` - - We think of the value as progressing through a pipe of several - transformations, much like pipes in UNIX - - - >>> double = lambda i: 2 * i - >>> pipe(double, str)(3) - '6' - """ - return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/progress.py b/brainpy/_src/tools/progress.py new file mode 100644 index 00000000..13b6a157 --- /dev/null +++ b/brainpy/_src/tools/progress.py @@ -0,0 +1,519 @@ +"""Python utilities required by Keras.""" + +import binascii +import codecs +import importlib +import marshal +import os +import re +import sys +import time +import types as python_types + +import numpy as np + + +# isort: off + + +def func_dump(func): + """Serializes a user defined function. + + Args: + func: the function to serialize. + + Returns: + A tuple `(code, defaults, closure)`. + """ + if os.name == "nt": + raw_code = marshal.dumps(func.__code__).replace(b"\\", b"/") + code = codecs.encode(raw_code, "base64").decode("ascii") + else: + raw_code = marshal.dumps(func.__code__) + code = codecs.encode(raw_code, "base64").decode("ascii") + defaults = func.__defaults__ + if func.__closure__: + closure = tuple(c.cell_contents for c in func.__closure__) + else: + closure = None + return code, defaults, closure + + +def func_load(code, defaults=None, closure=None, globs=None): + """Deserializes a user defined function. + + Args: + code: bytecode of the function. + defaults: defaults of the function. + closure: closure of the function. + globs: dictionary of global objects. + + Returns: + A function object. + """ + if isinstance(code, (tuple, list)): # unpack previous dump + code, defaults, closure = code + if isinstance(defaults, list): + defaults = tuple(defaults) + + def ensure_value_to_cell(value): + """Ensures that a value is converted to a python cell object. + + Args: + value: Any value that needs to be casted to the cell type + + Returns: + A value wrapped as a cell object (see function "func_load") + """ + + def dummy_fn(): + value # just access it so it gets captured in .__closure__ + + cell_value = dummy_fn.__closure__[0] + if not isinstance(value, type(cell_value)): + return cell_value + return value + + if closure is not None: + closure = tuple(ensure_value_to_cell(_) for _ in closure) + try: + raw_code = codecs.decode(code.encode("ascii"), "base64") + except (UnicodeEncodeError, binascii.Error): + raw_code = code.encode("raw_unicode_escape") + code = marshal.loads(raw_code) + if globs is None: + globs = globals() + return python_types.FunctionType( + code, globs, name=code.co_name, argdefs=defaults, closure=closure + ) + + +class Progbar: + """Displays a progress bar. + + Args: + target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that should *not* + be averaged over time. Metrics in this list will be displayed as-is. + All others will be averaged by the progbar before display. + interval: Minimum visual progress update interval (in seconds). + unit_name: Display name for step counts (usually "step" or "sample"). + """ + + def __init__( + self, + target, + width=30, + verbose=1, + interval=0.05, + stateful_metrics=None, + unit_name="step", + ): + self.target = target + self.width = width + self.verbose = verbose + self.interval = interval + self.unit_name = unit_name + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + + self._dynamic_display = ( + (hasattr(sys.stdout, "isatty") and sys.stdout.isatty()) + or "ipykernel" in sys.modules + or "posix" in sys.modules + or "PYCHARM_HOSTED" in os.environ + ) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + self._time_at_epoch_start = self._start + self._time_at_epoch_end = None + self._time_after_first_step = None + + def update(self, current, values=None, finalize=None): + """Updates the progress bar. + + Args: + current: Index of current step. + values: List of tuples: `(name, value_for_last_step)`. If `name` is + in `stateful_metrics`, `value_for_last_step` will be displayed + as-is. Else, an average of the metric over time will be + displayed. + finalize: Whether this is the last update for the progress bar. If + `None`, uses `current >= self.target`. Defaults to `None`. + """ + if finalize is None: + if self.target is None: + finalize = False + else: + finalize = current >= self.target + + values = values or [] + for k, v in values: + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + # In the case that progress bar doesn't have a target value in + # the first epoch, both on_batch_end and on_epoch_end will be + # called, which will cause 'current' and 'self._seen_so_far' to + # have the same value. Force the minimal value to 1 here, + # otherwise stateful_metric will be 0s. + value_base = max(current - self._seen_so_far, 1) + if k not in self._values: + self._values[k] = [v * value_base, value_base] + else: + self._values[k][0] += v * value_base + self._values[k][1] += value_base + else: + # Stateful metrics output a numeric value. This representation + # means "take an average from a single value" but keeps the + # numeric formatting. + self._values[k] = [v, 1] + self._seen_so_far = current + + message = "" + now = time.time() + info = f" - {now - self._start:.0f}s" + if current == self.target: + self._time_at_epoch_end = now + if self.verbose == 1: + if now - self._last_update < self.interval and not finalize: + return + + prev_total_width = self._total_width + if self._dynamic_display: + message += "\b" * prev_total_width + message += "\r" + else: + message += "\n" + + if self.target is not None: + numdigits = int(np.log10(self.target)) + 1 + bar = ("%" + str(numdigits) + "d/%d [") % (current, self.target) + prog = float(current) / self.target + prog_width = int(self.width * prog) + if prog_width > 0: + bar += "=" * (prog_width - 1) + if current < self.target: + bar += ">" + else: + bar += "=" + bar += "." * (self.width - prog_width) + bar += "]" + else: + bar = "%7d/Unknown" % current + + self._total_width = len(bar) + message += bar + + time_per_unit = self._estimate_step_duration(current, now) + + if self.target is None or finalize: + info += self._format_time(time_per_unit, self.unit_name) + else: + eta = time_per_unit * (self.target - current) + if eta > 3600: + eta_format = "%d:%02d:%02d" % ( + eta // 3600, + (eta % 3600) // 60, + eta % 60, + ) + elif eta > 60: + eta_format = "%d:%02d" % (eta // 60, eta % 60) + else: + eta_format = "%ds" % eta + + info = f" - ETA: {eta_format}" + + for k in self._values_order: + info += f" - {k}:" + if isinstance(self._values[k], list): + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if abs(avg) > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + else: + info += f" {self._values[k]}" + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += " " * (prev_total_width - self._total_width) + + if finalize: + info += "\n" + + message += info + print_msg(message, line_break=False) + message = "" + + elif self.verbose == 2: + if finalize: + numdigits = int(np.log10(self.target)) + 1 + count = ("%" + str(numdigits) + "d/%d") % (current, self.target) + info = count + info + for k in self._values_order: + info += f" - {k}:" + avg = np.mean( + self._values[k][0] / max(1, self._values[k][1]) + ) + if avg > 1e-3: + info += f" {avg:.4f}" + else: + info += f" {avg:.4e}" + if self._time_at_epoch_end: + time_per_epoch = ( + self._time_at_epoch_end - self._time_at_epoch_start + ) + avg_time_per_step = time_per_epoch / self.target + self._time_at_epoch_start = now + self._time_at_epoch_end = None + info += " -" + self._format_time(time_per_epoch, "epoch") + info += " -" + self._format_time( + avg_time_per_step, self.unit_name + ) + info += "\n" + message += info + print_msg(message, line_break=False) + message = "" + + self._last_update = now + + def add(self, n, values=None): + self.update(self._seen_so_far + n, values) + + def _format_time(self, time_per_unit, unit_name): + """format a given duration to display to the user. + + Given the duration, this function formats it in either milliseconds + or seconds and displays the unit (i.e. ms/step or s/epoch) + Args: + time_per_unit: the duration to display + unit_name: the name of the unit to display + Returns: + a string with the correctly formatted duration and units + """ + formatted = "" + if time_per_unit >= 1 or time_per_unit == 0: + formatted += f" {time_per_unit:.0f}s/{unit_name}" + elif time_per_unit >= 1e-3: + formatted += f" {time_per_unit * 1000.0:.0f}ms/{unit_name}" + else: + formatted += f" {time_per_unit * 1000000.0:.0f}us/{unit_name}" + return formatted + + def _estimate_step_duration(self, current, now): + """Estimate the duration of a single step. + + Given the step number `current` and the corresponding time `now` this + function returns an estimate for how long a single step takes. If this + is called before one step has been completed (i.e. `current == 0`) then + zero is given as an estimate. The duration estimate ignores the duration + of the (assumed to be non-representative) first step for estimates when + more steps are available (i.e. `current>1`). + + Args: + current: Index of current step. + now: The current time. + + Returns: Estimate of the duration of a single step. + """ + if current: + # there are a few special scenarios here: + # 1) somebody is calling the progress bar without ever supplying + # step 1 + # 2) somebody is calling the progress bar and supplies step one + # multiple times, e.g. as part of a finalizing call + # in these cases, we just fall back to the simple calculation + if self._time_after_first_step is not None and current > 1: + time_per_unit = (now - self._time_after_first_step) / ( + current - 1 + ) + else: + time_per_unit = (now - self._start) / current + + if current == 1: + self._time_after_first_step = now + return time_per_unit + else: + return 0 + + def _update_stateful_metrics(self, stateful_metrics): + self.stateful_metrics = self.stateful_metrics.union(stateful_metrics) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Args: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [ + (i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches) + ] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Args: + arrays: Single array or list of arrays. + start: can be an integer index (start index) or a list/array of indices + stop: integer (stop index); should be None if `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError( + "The stop argument has to be None if the value of start " + f"is a list. Received start={start}, stop={stop}" + ) + elif isinstance(arrays, list): + if hasattr(start, "__len__"): + # hdf5 datasets only support list objects as indices + if hasattr(start, "shape"): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + return [ + None + if x is None + else None + if not hasattr(x, "__getitem__") + else x[start:stop] + for x in arrays + ] + else: + if hasattr(start, "__len__"): + if hasattr(start, "shape"): + start = start.tolist() + return arrays[start] + if hasattr(start, "__getitem__"): + return arrays[start:stop] + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Args: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] + + +def to_snake_case(name): + intermediate = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) + insecure = re.sub("([a-z])([A-Z])", r"\1_\2", intermediate).lower() + # If the class is private the name starts with "_" which is not secure + # for creating scopes. We prefix the name with "private" in this case. + if insecure[0] != "_": + return insecure + return "private" + insecure + + +def check_for_unexpected_keys(name, input_dict, expected_values): + unknown = set(input_dict.keys()).difference(expected_values) + if unknown: + raise ValueError( + f"Unknown entries in {name} dictionary: {list(unknown)}. " + f"Only expected following keys: {expected_values}" + ) + + +def validate_kwargs( + kwargs, allowed_kwargs, error_message="Keyword argument not understood:" +): + """Checks that all keyword arguments are in the set of allowed keys.""" + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError(error_message, kwarg) + + +def default(method): + """Decorates a method to detect overrides in subclasses.""" + method._is_default = True + return method + + +def is_default(method): + """Check if a method is decorated with the `default` wrapper.""" + return getattr(method, "_is_default", False) + + +def populate_dict_with_module_objects(target_dict, modules, obj_filter): + for module in modules: + for name in dir(module): + obj = getattr(module, name) + if obj_filter(obj): + target_dict[name] = obj + + +class LazyLoader(python_types.ModuleType): + """Lazily import a module, mainly to avoid pulling in large dependencies.""" + + def __init__(self, local_name, parent_module_globals, name): + self._local_name = local_name + self._parent_module_globals = parent_module_globals + super().__init__(name) + + def _load(self): + """Load the module and insert it into the parent's globals.""" + # Import the target module and insert it into the parent's namespace + module = importlib.import_module(self.__name__) + self._parent_module_globals[self._local_name] = module + # Update this object's dict so that if someone keeps a reference to the + # LazyLoader, lookups are efficient (__getattr__ is only called on + # lookups that fail). + self.__dict__.update(module.__dict__) + return module + + def __getattr__(self, item): + module = self._load() + return getattr(module, item) + + +def print_msg(message, line_break=True): + """Print the message to absl logging or stdout.""" + if line_break: + sys.stdout.write(message + "\n") + else: + sys.stdout.write(message) + sys.stdout.flush() diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py deleted file mode 100644 index c285e561..00000000 --- a/brainpy/_src/tools/tests/test_functions.py +++ /dev/null @@ -1,24 +0,0 @@ - -import unittest - -import brainpy as bp -import brainpy.math as bm - - -class TestFunction(unittest.TestCase): - def test_compose(self): - f = lambda a: a + 1 - g = lambda a: a * 10 - fun1 = bp.tools.compose(f, g) - fun2 = bp.tools.pipe(g, f) - - arr = bm.random.randn(10) - r1 = fun1(arr) - r2 = fun2(arr) - groundtruth = f(g(arr)) - self.assertTrue(bm.allclose(r1, r2)) - self.assertTrue(bm.allclose(r1, groundtruth)) - bm.clear_buffer_memory() - - - diff --git a/brainpy/errors.py b/brainpy/errors.py index e59bb326..453c9c81 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -38,7 +38,12 @@ class AnalyzerError(BrainPyError): class PackageMissingError(BrainPyError): """The package missing error. """ - pass + + @classmethod + def by_purpose(cls, name, purpose): + err = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + return cls(err) class BackendNotInstalled(BrainPyError): @@ -236,9 +241,5 @@ def __init__(self, name): ''') - - class SharedArgError(BrainPyError): pass - - diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 02f67134..9a64f9f2 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -1,103 +1,102 @@ -# -*- coding: utf-8 -*- - -# data structure -from .ndarray import * -from .delayvars import * -from .interoperability import * -from .datatypes import * -from .compat_numpy import * -from .compat_tensorflow import * -from .compat_pytorch import * -from .einops import * - -# functions -from .activations import * -from . import activations - -# operators -from .pre_syn_post import * -from .op_register import * -from . import surrogate, event, sparse, jitconn - -# Variable and Objects for object-oriented JAX transformations -from .oo_transform import * - -# environment settings -from .modes import * -from .environment import * -from .scales import * -from .others import * - -# high-level numpy operations -from . import fft -from . import linalg -from . import random - -# taichi operations -from . import tifunc - -# others -from . import sharding - -import jax.numpy as jnp -from jax import config - -del jnp, config - -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - -from brainpy._src.math import defaults -from brainpy._src.deprecations import deprecation_getattr -__deprecations = { - "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", - sparse.seg_matmul), - 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_homo instead.", - jitconn.event_mv_prob_homo), - 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", - jitconn.event_mv_prob_uniform), - 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.event_mv_prob_normal instead.", - jitconn.event_mv_prob_normal), - 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_homo instead.", - jitconn.mv_prob_homo), - 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_uniform instead.", - jitconn.mv_prob_uniform), - 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " - "Use brainpy.math.jitconn.mv_prob_normal instead.", - jitconn.mv_prob_normal), - 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " - "Use brainpy.math.sparse.csrmv instead.", - sparse.csrmv), - 'cusparse_coo_matvec': ("brainpy.math.cusparse_coo_matvec is deprecated. " - "Use brainpy.math.sparse.coomv instead.", - sparse.coomv), - 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " - "Use brainpy.math.sparse.coo_to_csr instead.", - sparse.coo_to_csr), - 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " - "Use brainpy.math.sparse.csr_to_coo instead.", - sparse.csr_to_coo), - 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " - "Use brainpy.math.sparse.csr_to_dense instead.", - sparse.csr_to_dense), - 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " - "Use brainpy.math.event.csr_to_dense instead.", - event.csrmv), - 'event_info': ("brainpy.math.event_info is deprecated. " - "Use brainpy.math.event.info instead.", - event.info), -} - -__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) -del deprecation_getattr, defaults +# -*- coding: utf-8 -*- + +# data structure +from .ndarray import * +from .delayvars import * +from .interoperability import * +from .datatypes import * +from .compat_numpy import * +from .compat_tensorflow import * +from .compat_pytorch import * +from .einops import * + +# functions +from .activations import * +from . import activations + +# operators +from .pre_syn_post import * +from .op_register import * +from . import surrogate, event, sparse, jitconn + +# Variable and Objects for object-oriented JAX transformations +from .oo_transform import * + +# environment settings +from .modes import * +from .environment import * +from .scales import * +from .others import * + +# high-level numpy operations +from . import fft +from . import linalg +from . import random + +# taichi operations +from . import tifunc + +# others +from . import sharding + +import jax.numpy as jnp +from jax import config + +del jnp, config + +from brainpy._src.math.surrogate._compt import ( + spike_with_sigmoid_grad as spike_with_sigmoid_grad, + spike_with_linear_grad as spike_with_linear_grad, + spike_with_gaussian_grad as spike_with_gaussian_grad, + spike_with_mg_grad as spike_with_mg_grad, +) + +from brainpy._src.math import defaults +from brainpy._src.deprecations import deprecation_getattr +from brainpy._src.dependency_check import import_taichi, import_numba + +import_taichi(error_if_not_found=False) +import_numba(error_if_not_found=False) + +__deprecations = { + "sparse_matmul": ("brainpy.math.sparse_matmul is deprecated. Use brainpy.math.sparse.seg_matmul instead.", + sparse.seg_matmul), + 'csr_matvec': ("brainpy.math.csr_matvec is deprecated. Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'event_matvec_prob_conn_homo_weight': ("brainpy.math.event_matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_homo instead.", + jitconn.event_mv_prob_homo), + 'event_matvec_prob_conn_uniform_weight': ("brainpy.math.event_matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_uniform instead.", + jitconn.event_mv_prob_uniform), + 'event_matvec_prob_conn_normal_weight': ("brainpy.math.event_matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.event_mv_prob_normal instead.", + jitconn.event_mv_prob_normal), + 'matvec_prob_conn_homo_weight': ("brainpy.math.matvec_prob_conn_homo_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_homo instead.", + jitconn.mv_prob_homo), + 'matvec_prob_conn_uniform_weight': ("brainpy.math.matvec_prob_conn_uniform_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_uniform instead.", + jitconn.mv_prob_uniform), + 'matvec_prob_conn_normal_weight': ("brainpy.math.matvec_prob_conn_normal_weight is deprecated. " + "Use brainpy.math.jitconn.mv_prob_normal instead.", + jitconn.mv_prob_normal), + 'cusparse_csr_matvec': ("brainpy.math.cusparse_csr_matvec is deprecated. " + "Use brainpy.math.sparse.csrmv instead.", + sparse.csrmv), + 'coo_to_csr': ("brainpy.math.coo_to_csr is deprecated. " + "Use brainpy.math.sparse.coo_to_csr instead.", + sparse.coo_to_csr), + 'csr_to_coo': ("brainpy.math.csr_to_coo is deprecated. " + "Use brainpy.math.sparse.csr_to_coo instead.", + sparse.csr_to_coo), + 'csr_to_dense': ("brainpy.math.csr_to_dense is deprecated. " + "Use brainpy.math.sparse.csr_to_dense instead.", + sparse.csr_to_dense), + 'event_csr_matvec': ("brainpy.math.event_csr_matvec is deprecated. " + "Use brainpy.math.event.csr_to_dense instead.", + event.csrmv), +} + +__getattr__ = deprecation_getattr(__name__, __deprecations, redirects=defaults.__all__, redirect_module=defaults) +del deprecation_getattr, defaults diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 3b0c3f51..e4570f6f 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - # add as add, + add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 0a17cae7..02e98b8f 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,5 +1,3 @@ - from brainpy._src.math.event import ( csrmv as csrmv, - info as info, ) diff --git a/brainpy/math/jitconn.py b/brainpy/math/jitconn.py index 90a028b7..a87d27d5 100644 --- a/brainpy/math/jitconn.py +++ b/brainpy/math/jitconn.py @@ -1,10 +1,10 @@ -from brainpy._src.math.jitconn import ( - event_mv_prob_homo as event_mv_prob_homo, - event_mv_prob_uniform as event_mv_prob_uniform, - event_mv_prob_normal as event_mv_prob_normal, - - mv_prob_homo as mv_prob_homo, - mv_prob_uniform as mv_prob_uniform, - mv_prob_normal as mv_prob_normal, -) - +from brainpy._src.math.jitconn import ( + event_mv_prob_homo as event_mv_prob_homo, + event_mv_prob_uniform as event_mv_prob_uniform, + event_mv_prob_normal as event_mv_prob_normal, + + mv_prob_homo as mv_prob_homo, + mv_prob_uniform as mv_prob_uniform, + mv_prob_normal as mv_prob_normal, +) + diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 7654731d..548a987d 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -59,7 +59,3 @@ eval_shape as eval_shape, ) -from brainpy._src.math.object_transform.variables import ( - VariableStack as VariableStack, -) - diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index a48268ef..c0fcb67a 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -1,14 +1,12 @@ -# -*- coding: utf-8 -*- - - -from brainpy._src.math.op_register import ( - CustomOpByNumba, - compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, -) - -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy._src.math.op_register.ad_support import defjvp - - +# -*- coding: utf-8 -*- +from brainpy._src.math.op_register import ( + CustomOpByNumba, + compile_cpu_signature_with_numba, + clean_caches, + check_kernels_count, +) + +from brainpy._src.math.op_register.base import XLACustomOp +from brainpy._src.math.op_register.ad_support import defjvp + + diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9..aa86679e 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,8 +1,9 @@ -from brainpy._src.math.sparse import ( - csrmv, - coomv, +from brainpy._src.math.sparse import ( seg_matmul, +) +from brainpy._src.math.sparse import ( + csrmv, csr_to_dense as csr_to_dense, csr_to_coo as csr_to_coo, diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py index 63f3cbe4..bea49c22 100644 --- a/brainpy/math/tifunc.py +++ b/brainpy/math/tifunc.py @@ -1,26 +1,25 @@ -# -*- coding: utf-8 -*- - -from brainpy._src.math.tifunc import ( - taichi_lcg_rand, - - # warp reduction primitives - warp_reduce_sum, - - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand -) +# -*- coding: utf-8 -*- + +from brainpy._src.math.tifunc import ( + + # warp reduction primitives + warp_reduce_sum, + + # random number generator + lfsr88_key, + lfsr88_next_key, + lfsr88_normal, + lfsr88_randn, + lfsr88_random_integers, + lfsr88_randint, + lfsr88_uniform, + lfsr88_rand, + lfsr113_key, + lfsr113_next_key, + lfsr113_normal, + lfsr113_randn, + lfsr113_random_integers, + lfsr113_randint, + lfsr113_uniform, + lfsr113_rand +) diff --git a/brainpy/tools.py b/brainpy/tools.py index 233269dc..0f3a4c0e 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -45,9 +45,4 @@ ) -from brainpy._src.tools.functions import ( - compose as compose, - pipe as pipe, -) - diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 0b78315a..5c8cba0f 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,52 +3,13 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. -Advanced Math -------------- .. toctree:: - :maxdepth: 1 - - tutorial_advanced/compilation.ipynb - tutorial_advanced/differentiation.ipynb - - -Interoperation --------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/integrate_flax_into_brainpy.ipynb - tutorial_advanced/integrate_bp_lif_into_flax.ipynb - tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb - - -Brain Dynamics Dedicated Operators ----------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/operator_custom_with_numba.ipynb - tutorial_advanced/operator_custom_with_taichi.ipynb - - -Developer Guides ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/contributing.md - - -Others ------- - -.. toctree:: - :maxdepth: 1 - - tutorial_advanced/advanced_lowdim_analysis.ipynb + :maxdepth: 2 + tutorial_advanced/1_advanced_math.rst + tutorial_advanced/2_interoperation.rst + tutorial_advanced/3_dedicated_operators.rst + tutorial_advanced/4_developer_guides.rst + tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 9ed9cf46..754e0d81 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,5 +77,4 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape - VariableStack diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 2e0bb190..46ce3822 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -10,285 +10,71 @@ Installation Linux, and MacOS. It only relies on Python libraries. -Installation with pip ---------------------- +Minimum requirements +-------------------- -You can install ``BrainPy`` from the `pypi `_. -To do so, use: +To install brainpy with minimum requirements (only depends on ``jax``), you can use: .. code-block:: bash - pip install brainpy - -To update the latest BrainPy, you can use - -.. code-block:: bash - - pip install -U brainpy - - -If you want to install the pre-release version (the latest development version) -of BrainPy, you can use: - -.. code-block:: bash - - pip install --pre brainpy - - - -Installation from source ------------------------- - -If you decide not to use ``pip``, you can install ``BrainPy`` from -`GitHub `_, -or `OpenI `_. - -To do so, use: - -.. code-block:: bash - - pip install git+https://github.com/PKU-NIP-Lab/BrainPy + pip install brainpy[cpu_mini] # for CPU # or - pip install git+https://git.openi.org.cn/OpenI/BrainPy + pip install brainpy[cuda_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for GPU (Linux only) -Dependency 1: NumPy --------------------------------- -In order to make BrainPy work normally, users should install -several dependent Python packages. +CPU with all dependencies +------------------------- -The basic function of ``BrainPy`` only relies on `NumPy`_, which is very -easy to install through ``pip`` or ``conda``: +To install a CPU-only version of BrainPy, which might be useful for doing local development on a laptop, you can run .. code-block:: bash - pip install numpy - - # or - - conda install numpy - -Dependency 2: JAX ------------------ - -BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables -users to run Python code on CPU, GPU, and TPU devices. Core functionalities of -BrainPy (>=2.0.0) have been migrated to the JAX backend. - -Linux -^^^^^ - -Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or -later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS -systems are available at + pip install brainpy[cpu] -- for CPU: https://storage.googleapis.com/jax-releases/jax_releases.html -- for GPU: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html -If you want to install a CPU-only version of `jax` and `jaxlib`, you can run +GPU with all dependencies +------------------------- -.. code-block:: bash - - pip install --upgrade "jax[cpu]" - -If you want to install JAX with both CPU and NVidia GPU support, you must first install -`CUDA`_ and `CuDNN`_, if they have already been installed. Next, run +BrainPy supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. +To install a GPU-only version of BrainPy, you can run .. code-block:: bash - # CUDA 12 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - - # CUDA 11 installation - # Note: wheels only available on linux. - pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - -In the event of a version mismatch error with JAX, such as encountering an error message like: - -.. code-block:: text + pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 + pip install brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 - CUDA backend failed to initialize: Found CUDA version 12000, but JAX was built against version 12020, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) -You will need to employ an alternative installation method that aligns with your environment's CUDA version. This can be achieved using the following commands: -.. code-block:: bash +``brainpylib`` +-------------- - # CUDA 12 installation - pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # CUDA 11 installation - pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html +``brainpylib`` defines a set of useful operators for building and simulating spiking neural networks. -Alternatively, you can download the preferred release ".whl" file for jaxlib -from the above release links, and install it via ``pip``: +To install the ``brainpylib`` package on CPU devices, you can run .. code-block:: bash - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -.. note:: - - Note that the versions of jaxlib and jax should be consistent. - - For example, if you are using jax==0.4.15, you would better install jax==0.4.15. - + pip install brainpylib -MacOS -^^^^^ -If you are using macOS Intel, we recommend you first to install the Miniconda Intel installer: +To install the ``brainpylib`` package on CUDA 11, you can run -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg -2. Then click the downloaded package and install it. - - -If you are using the latest M1 macOS version, you'd better to install the Miniconda M1 installer: - - -1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg -2. Then click the downloaded package and install it. - - -Finally, you can install `jax` and `jaxlib` as the same as the Linux platform. .. code-block:: bash - pip install --upgrade "jax[cpu]" - - - -Windows -^^^^^^^ - -For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed -directly from the PyPi channel. - -.. code-block:: bash + pip install brainpylib-cu11x - pip install jax jaxlib +To install the ``brainpylib`` package on CUDA 12, you can run -For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed -from the community supports. Specifically, you can install `jax` and `jaxlib` through: .. code-block:: bash - pip install "jax[cpu]" -f https://whls.blob.core.windows.net/unstable/index.html - -If you are using GPU, you can install GPU-versioned wheels through: - -.. code-block:: bash - - pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html - -Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by -downloading binary releases of JAX for Windows from -https://whls.blob.core.windows.net/unstable/index.html . -Then install it via ``pip``: - -.. code-block:: bash - - pip install xxx-0.4.15-xxx.whl - - pip install jax==0.4.15 - -WSL -^^^ - -Moreover, for Windows 10+ system, we recommend using `Windows Subsystem for Linux (WSL)`_. -The installation guide can be found in -`WSL Installation Guide for Windows 10/11 `_. -Then, you can install JAX in WSL just like the installation step in Linux/MacOs. - - -Dependency 3: brainpylib ------------------------- - -Many customized operators in BrainPy are implemented in ``brainpylib``. -``brainpylib`` can also be installed from pypi according to your devices. -For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators. -You can install CPU-version `brainpylib` by: - -.. code-block:: bash - - # CPU installation - pip install --upgrade brainpylib - -For Nvidia GPU users, ``brainpylib`` only support Linux system and WSL2 subsystem. You can install the CUDA-version by using: - -.. code-block:: bash - - # CUDA 12 installation - pip install --upgrade brainpylib-cu12x - -.. code-block:: bash - - # CUDA 11 installation - pip install --upgrade brainpylib-cu11x - -Dependency 4: taichi ------------------------- -Now BrainPy supports customized operators implemented in `taichi`_. You can install the latest version of `taichi`_ by: - -.. code-block:: bash - - pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly - -.. _taichi: https://www.taichi-lang.org - -And you can try it in the `operator custom with taichi <../tutorial_advanced/operator_custom_with_taichi.html>`_ tutorial page -Attention: customized operators is still in the experimental stage. If you meet any problems, please contact us through the issue page. - -Running BrainPy with docker ------------------------- - -If you want to use BrainPy in docker, you can use the following command to pull the docker image: - -.. code:: bash - - docker pull brainpy/brainpy:latest - -You can then run the docker image by: - -.. code:: bash - - docker run -it --platform linux/amd64 brainpy/brainpy:latest - -Please notice that BrainPy docker image is based on the `ubuntu22.04` image, so it only support CPU version of BrainPy. - - -Running BrainPy online with binder ----------------------------------- - -Click on the following link to launch the Binder environment with the -BrainPy repository: - -|image1| - -Wait for the Binder environment to build. This might take a few moments. - -Once the environment is ready, you'll be redirected to a Jupyter -notebook interface within your web browser. - -.. |image1| image:: https://camo.githubusercontent.com/581c077bdbc6ca6899c86d0acc6145ae85e9d80e6f805a1071793dbe48917982/68747470733a2f2f6d7962696e6465722e6f72672f62616467655f6c6f676f2e737667 - :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main - - -.. _NumPy: https://numpy.org/ -.. _Matplotlib: https://matplotlib.org/ -.. _JAX: https://github.com/google/jax -.. _Windows Subsystem for Linux (WSL): https://docs.microsoft.com/en-us/windows/wsl/about -.. _build JAX from source: https://jax.readthedocs.io/en/latest/developer.html -.. _SymPy: https://github.com/sympy/sympy -.. _Numba: https://numba.pydata.org/ -.. _CUDA: https://developer.nvidia.com/cuda-downloads -.. _CuDNN: https://developer.nvidia.com/CUDNN + pip install brainpylib-cu12x diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index cc3a3857..11bf5311 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,16 +1,7 @@ BDP Toolboxes ================== - - - This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. - - -Differential Equations ------------------------ - - .. toctree:: :maxdepth: 1 @@ -19,34 +10,11 @@ Differential Equations tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations - - -Toolbox for Modeling -------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights - tutorial_toolbox/inputs - - -Toolbox for Training --------------------- - -.. toctree:: - :maxdepth: 1 - tutorial_toolbox/optimizers + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient + tutorial_toolbox/inputs - -State Resetting, Saving and Loading ------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 57d18332..7c9a1c87 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,76 +3,11 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. - -Math Foundation ---------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_math/variables - tutorial_math/control_flows - tutorial_math/Numpy_like_Operations.ipynb - tutorial_math/Dedicated_Operators.ipynb - tutorial_math/einops_in_brainpy.ipynb - - -Model Building with Existing Modules ------------------------------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_building/overview_of_dynamic_model - tutorial_building/build_conductance_neurons_v2.ipynb - tutorial_building/phenon_synapse_models.ipynb - tutorial_building/kinetic_synapse_models.ipynb - tutorial_building/build_network_models - - -Model Building by Customizing New Modules ------------------------------------------ - -.. toctree:: - :maxdepth: 1 - - tutorial_building/customize_neuron_models - tutorial_building/customize_synapse_models - tutorial_building/how_to_customze_a_synapse.ipynb - - -Model Simulation ----------------- - -.. toctree:: - :maxdepth: 1 - - tutorial_simulation/simulation_dsrunner.ipynb - tutorial_simulation/parallel_for_parameter_exploration.ipynb - tutorial_simulation/monitor_per_multiple_steps.ipynb - - -Model Training --------------- - -This tutorial shows how to train a dynamical system from data or task. - -.. toctree:: - :maxdepth: 1 - - tutorial_training/build_training_models.ipynb - tutorial_training/offline_training.ipynb - tutorial_training/online_training.ipynb - tutorial_training/bp_training.ipynb - tutorial_training/esn_introduction.ipynb - - -Model Analysis --------------- - .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - tutorial_analysis/lowdim_analysis - tutorial_analysis/highdim_analysis - tutorial_analysis/decision_making_model + tutorial_math/index + tutorial_building/index + tutorial_simulation/index + tutorial_training/index + tutorial_analysis/index diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index 9c7daff5..f9852745 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('delay') + spk = self.delay.at('I') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) diff --git a/examples/dynamics_training/integrator_rnn.py b/examples/dynamics_training/integrator_rnn.py index fc36845e..aeaf0c41 100644 --- a/examples/dynamics_training/integrator_rnn.py +++ b/examples/dynamics_training/integrator_rnn.py @@ -30,7 +30,7 @@ def train_data(): class RNN(bp.DynamicalSystem): def __init__(self, num_in, num_hidden): super(RNN, self).__init__() - self.rnn = bp.layers.RNNCell(num_in, num_hidden, train_state=True) + self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True) self.out = bp.layers.Dense(num_hidden, 1) def update(self, x): @@ -49,7 +49,7 @@ def loss(predictions, targets, l2_reg=2e-4): # define optimizer -lr = bp.optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975) +lr = bp.optim.ExponentialDecayLR(lr=0.025, decay_steps=1, decay_rate=0.99975) opt = bp.optim.Adam(lr=lr, eps=1e-1) # create a trainer diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt new file mode 100644 index 00000000..99361efa --- /dev/null +++ b/requirements-dev-raw.txt @@ -0,0 +1,12 @@ +numpy +jax +jaxlib +matplotlib +msgpack +tqdm +pathos + + +# test requirements +pytest +absl-py diff --git a/requirements-dev.txt b/requirements-dev.txt index 0e475e83..98398ae2 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,4 @@ numpy -numba brainpylib jax jaxlib @@ -7,7 +6,9 @@ matplotlib msgpack tqdm pathos -taichi==1.7.0 +taichi +numba + # test requirements pytest diff --git a/requirements.txt b/requirements.txt index 02fdebe8..ab5665e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ numpy jax tqdm -numba -taichi==1.7.0 diff --git a/setup.py b/setup.py index d7fd45e3..885bbf57 100644 --- a/setup.py +++ b/setup.py @@ -56,8 +56,8 @@ author='BrainPy Team', author_email='chao.brain@qq.com', packages=packages, - python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba', 'taichi==1.7.0'], + python_requires='>=3.9', + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", @@ -68,11 +68,12 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], - 'cuda': ['jax[cuda]', 'brainpylib-cu12x'], - 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], - 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], - 'tpu': ['jax[tpu]'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'tpu': ['jaxlib[tpu]', 'numba',], + 'cpu_mini': ['jaxlib>=0.4.13'], + 'cuda_mini': ['jaxlib[cuda12_pip]'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' @@ -89,6 +90,7 @@ 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', 'Intended Audience :: Science/Research', 'License :: OSI Approved :: Apache Software License', 'Topic :: Scientific/Engineering :: Bio-Informatics', From 8460a1d3ff09fe6e9a776ad65398c8a2799ee496 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:17:30 +0800 Subject: [PATCH 03/21] `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` (#639) * `clear_buffer_memory()` support clearing `array`, `compilation`, and `names` * add doc * upgrade * upgrade * try to fix * update CI --- .github/workflows/CI.yml | 192 ++--- .../synapses/tests/test_abstract_synapses.py | 254 +++--- brainpy/_src/math/environment.py | 19 +- brainpy/_src/math/jitconn/_event_matvec.py | 742 +++++++++++++++++- brainpy/_src/math/object_transform/naming.py | 7 +- 5 files changed, 927 insertions(+), 287 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 95bd8eaf..d29b07eb 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -11,6 +11,12 @@ on: branches: - '**' # matches every branch + +permissions: + contents: read # to fetch code + actions: write # to cancel previous workflows + + #on: # push: # branches: [ master ] @@ -27,6 +33,10 @@ jobs: python-version: [ "3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -35,16 +45,9 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | cd brainpy @@ -82,40 +85,6 @@ jobs: pytest _src/ -# test_linux_py37: -# runs-on: ubuntu-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - test_macos: runs-on: macos-latest strategy: @@ -124,6 +93,10 @@ jobs: python-version: ["3.9", "3.10", "3.11"] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -132,16 +105,40 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 + - name: Test with pytest run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + cd brainpy + pytest -n auto --tb=short _src/ + + + test_windows: + strategy: + fail-fast: false + matrix: + os: [ win-2019-16core ] + arch: [ AMD64 ] + python-version: ["3.9", "3.10", "3.11"] + runs-on: ${{ matrix.os }} + + steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install - name: Test with pytest run: | cd brainpy @@ -178,104 +175,3 @@ jobs: cd brainpy pytest _src/ -# test_macos_py37: -# runs-on: macos-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: [ "3.7" ] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi -# pip install jax==0.3.25 -# pip install jaxlib==0.3.25 -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ -# - - -# test_windows: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.9", "3.10", "3.11"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install -r requirements-dev.txt -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd brainpy -# pytest _src/ - - -# test_windows_py37: -# runs-on: windows-latest -# strategy: -# fail-fast: false -# matrix: -# python-version: ["3.7"] -# -# steps: -# - uses: actions/checkout@v4 -# - name: Set up Python ${{ matrix.python-version }} -# uses: actions/setup-python@v5 -# with: -# python-version: ${{ matrix.python-version }} -# - name: Install dependencies -# run: | -# python -m pip install --upgrade pip -# python -m pip install flake8 pytest -# python -m pip install numpy>=1.21.0 -# python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver -# python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz -# python -m pip install -r requirements-dev.txt -# python -m pip install tqdm brainpylib -# pip uninstall brainpy -y -# python setup.py install -# - name: Lint with flake8 -# run: | -# # stop the build if there are Python syntax errors or undefined names -# flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics -# # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide -# flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# - name: Test with pytest -# run: | -# cd examples -# pytest ../brainpy/ diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index c3936f68..6db945ff 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -1,130 +1,124 @@ -# -*- coding: utf-8 -*- - -import pytest -from absl.testing import parameterized - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dynold.synapses import abstract_models -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) - - -class Test_Abstract_Synapse(parameterized.TestCase): - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_all2all_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(bp.synapses, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_one2one_synapse(self, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - comp_type=['sparse', 'dense'], - name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_sparse_synapse(self, comp_type, name, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5) - post_neu = bp.neurons.LIF(5) - syn_model = getattr(abstract_models, name) - syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'syn.g', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - expected_shape = (100, 5) - if isinstance(mode, bm.BatchingMode): - expected_shape = (mode.batch_size,) + expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) - self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) - bm.clear_buffer_memory() - - @parameterized.product( - post_ref_key=[None, 'refractory'], - stp=[None, bp.synplast.STD(), bp.synplast.STP()], - mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] - ) - def test_delta_synapse(self, post_ref_key, stp, mode): - bm.random.seed() - with bm.environment(mode=mode): - pre_neu = bp.neurons.LIF(5, ref_var=True) - post_neu = bp.neurons.LIF(3, ref_var=True) - syn = bp.synapses.Delta(pre_neu, post_neu, - conn=bp.conn.All2All(), - post_ref_key=post_ref_key, - stp=stp, ) - net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) - - # 运行模拟 - runner = bp.DSRunner(net, - monitors=['pre.V', 'post.V'], - inputs=('pre.input', 35.)) - runner(10.) - - pre_expected_shape = (100, 5) - post_expected_shape = (100, 3) - if isinstance(mode, bm.BatchingMode): - pre_expected_shape = (mode.batch_size,) + pre_expected_shape - post_expected_shape = (mode.batch_size,) + post_expected_shape - self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) - self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) - bm.clear_buffer_memory() +# -*- coding: utf-8 -*- + + +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dynold.synapses import abstract_models + + +class Test_Abstract_Synapse(parameterized.TestCase): + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_all2all_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(bp.synapses, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.All2All(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, monitors=['pre.V', 'syn.g', 'post.V'], inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_one2one_synapse(self, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.One2One(), stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + comp_type=['sparse', 'dense'], + name=['Exponential', 'DualExponential', 'Alpha', 'NMDA'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_sparse_synapse(self, comp_type, name, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5) + post_neu = bp.neurons.LIF(5) + syn_model = getattr(abstract_models, name) + syn = syn_model(pre_neu, post_neu, conn=bp.conn.FixedProb(0.1), comp_method=comp_type, stp=stp) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'syn.g', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + expected_shape = (100, 5) + if isinstance(mode, bm.BatchingMode): + expected_shape = (mode.batch_size, ) + expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, expected_shape) + self.assertTupleEqual(runner.mon['syn.g'].shape, expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, expected_shape) + bm.clear_buffer_memory() + + @parameterized.product( + post_ref_key=[None, 'refractory'], + stp=[None, bp.synplast.STD(), bp.synplast.STP()], + mode=[bm.nonbatching_mode, bm.BatchingMode(5), bm.TrainingMode(5)] + ) + def test_delta_synapse(self, post_ref_key, stp, mode): + bm.random.seed() + with bm.environment(mode=mode): + pre_neu = bp.neurons.LIF(5, ref_var=True) + post_neu = bp.neurons.LIF(3, ref_var=True) + syn = bp.synapses.Delta(pre_neu, post_neu, + conn=bp.conn.All2All(), + post_ref_key=post_ref_key, + stp=stp, ) + net = bp.Network(pre=pre_neu, syn=syn, post=post_neu) + + # 运行模拟 + runner = bp.DSRunner(net, + monitors=['pre.V', 'post.V'], + inputs=('pre.input', 35.)) + runner(10.) + + pre_expected_shape = (100, 5) + post_expected_shape = (100, 3) + if isinstance(mode, bm.BatchingMode): + pre_expected_shape = (mode.batch_size,) + pre_expected_shape + post_expected_shape = (mode.batch_size,) + post_expected_shape + self.assertTupleEqual(runner.mon['pre.V'].shape, pre_expected_shape) + self.assertTupleEqual(runner.mon['post.V'].shape, post_expected_shape) + bm.clear_buffer_memory() \ No newline at end of file diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 668f837c..7827dfed 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -2,6 +2,7 @@ import functools +import gc import inspect import os import re @@ -16,6 +17,7 @@ from . import modes from . import scales from . import defaults +from .object_transform import naming from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) @@ -681,7 +683,9 @@ def set_host_device_count(n): def clear_buffer_memory( platform: str = None, array: bool = True, - compilation: bool = False + transform: bool = True, + compilation: bool = False, + object_name: bool = False, ): """Clear all on-device buffers. @@ -698,9 +702,13 @@ def clear_buffer_memory( platform: str The device to clear its memory. array: bool - Clear all buffer array. + Clear all buffer array. Default is True. compilation: bool - Clear compilation cache. + Clear compilation cache. Default is False. + transform: bool + Clear transform cache. Default is True. + object_name: bool + Clear name cache. Default is True. """ if array: @@ -708,6 +716,11 @@ def clear_buffer_memory( buf.delete() if compilation: jax.clear_caches() + if transform: + naming.clear_stack_cache() + if object_name: + naming.clear_name_cache() + gc.collect() def disable_gpu_memory_preallocation(release_memory: bool = True): diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 976b72b9..ac62bbfa 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -45,9 +45,747 @@ def event_mv_prob_homo( if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + +event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + outdim_parallel=outdim_parallel) + + +### BRAINPYLIB ### + +def event_mv_prob_homo_brainpylib( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + weight = jnp.atleast_1d(jnp.asarray(weight)) + conn_prob = jnp.atleast_1d(jnp.asarray(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + return r + + +event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ + + +def event_mv_prob_uniform_brainpylib( + events: jax.Array, + w_low: float, + w_high: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_low = jnp.atleast_1d(as_jax(w_low)) + w_high = jnp.atleast_1d(as_jax(w_high)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ + + +def event_mv_prob_normal_brainpylib( + events: jax.Array, + w_mu: float, + w_sigma: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + events = as_jax(events) + w_mu = jnp.atleast_1d(as_jax(w_mu)) + w_sigma = jnp.atleast_1d(as_jax(w_sigma)) + conn_prob = jnp.atleast_1d(as_jax(conn_prob)) + clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) + with jax.ensure_compile_time_eval(): + if seed is None: + seed = int(np.random.randint(0, int(1e8))) + seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) + return event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] + + +event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ + + +def _event_matvec_prob_homo_abstract( + events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_homo_cpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_homo' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + weight, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_homo_gpu_translation( + c, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_homo_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1], ) + + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type + + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, weight, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(weight), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_homo_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, weight, clen, seed = primals + event_dot, weight_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_homo_p.bind(events, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(weight_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + if type(weight_dot) is ad.Zero: + if type(event_dot) is ad.Zero: + raise ValueError + dr = mv_prob_homo_p.bind(event_dot, + weight, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + elif type(event_dot) is ad.Zero: + dr = mv_prob_homo_p.bind(events, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + else: + dr = mv_prob_homo_p.bind(event_dot, + weight_dot, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, dr + + +def _event_matvec_prob_homo_transpose( + ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(weight) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_homo_p.bind(ct[0], + weight, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, weight, clen, seed + + +event_mv_prob_homo_p = Primitive('event_mv_prob_homo') +event_mv_prob_homo_p.multiple_results = True +event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) +event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation +ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp +ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose +register_general_batching(event_mv_prob_homo_p) + + +def _event_matvec_prob_uniform_abstract( + events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_low_dtype = _get_dtype(w_low) + _w_high_dtype = _get_dtype(w_low) + assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' + assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' + assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if w_low.ndim != 1: + raise ValueError('w_low must be a 1D scalar.') + if w_high.ndim != 1: + raise ValueError('w_high must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen must be a 1D scalar.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + assert w_low.dtype == w_high.dtype + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _event_matvec_prob_uniform_cpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_low, + w_high, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_uniform_gpu_translation( + c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_low, w_high, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_low), + c.get_shape(w_high), + c.get_shape(clen), + c.get_shape(seed),), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_uniform_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, w_low, w_high, clen, seed = primals + events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_uniform_p.bind(events, + w_low, + w_high, + clen, + seed, + shape=shape, + outdim_parallel=outdim_parallel, + transpose=transpose) + assert type(w_low_dot) is ad.Zero + assert type(w_high_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_uniform_p.bind(events_dot, + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, r_dot + + +def _event_matvec_prob_uniform_transpose( + ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_low) is not ad.UndefinedPrimal + assert type(w_high) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_uniform_p.bind(ct[0], + w_low, + w_high, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_low, w_high, clen, seed + + +event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') +event_mv_prob_uniform_p.multiple_results = True +event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) +event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation +register_general_batching(event_mv_prob_uniform_p) +ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp +ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose + + +def _event_matvec_prob_normal_abstract( + events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] + _w_mu_dtype = _get_dtype(w_mu) + _w_sigma_dtype = _get_dtype(w_sigma) + assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' + assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' + assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] + + if w_mu.ndim != 1: + raise ValueError('w_mu should be a 1D scalar.') + if w_sigma.ndim != 1: + raise ValueError('w_sigma should be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('clen should be a 1D scalar.') + if events.ndim != 1: + raise ValueError('events should be a 1D vector.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + assert w_mu.dtype == w_sigma.dtype + + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be a boolean value.') + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be a boolean value.') + + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) + return [out] + + +def _get_types(event_shape): + event_type = event_shape.element_type() + if event_type == jnp.bool_: + event_type = b'_bool' + out_dtype = dtypes.canonicalize_dtype(float) + elif event_type == jnp.float32: + event_type = b'_float' + out_dtype = event_shape.element_type() + elif event_type == jnp.float64: + event_type = b'_double' + out_dtype = event_shape.element_type() + else: + raise TypeError + + if out_dtype == jnp.float32: + type_name = b'_float' + elif out_dtype == jnp.float64: + type_name = b'_double' + else: + raise TypeError + + return out_dtype, event_type, type_name + + +def _event_matvec_prob_normal_cpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + import_brainpylib_cpu_ops() + n_row, n_col = (shape[1], shape[0]) if transpose else shape + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + if outdim_parallel: + fn = b'cpu_event_matvec_prob_normal' + type_name + event_type + else: + fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, + w_mu, + w_sigma, + clen, + seed, + xla_client.ops.ConstantLiteral(c, n_row), + xla_client.ops.ConstantLiteral(c, n_col)), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), + xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + ) + + +def _event_matvec_prob_normal_gpu_translation( + c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + gpu_ops = import_brainpylib_gpu_ops() + if gpu_ops is None: + raise GPUOperatorNotFound(event_mv_prob_normal_p.name) + + out_dtype, event_type, type_name = _get_types(c.get_shape(events)) + + opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], + shape[0] if transpose else shape[1]) + if outdim_parallel: + fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type + else: + fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type + return xla_client.ops.CustomCallWithLayout( + c, + fn, + operands=(events, w_mu, w_sigma, clen, seed), + operand_shapes_with_layout=(c.get_shape(events), + c.get_shape(w_mu), + c.get_shape(w_sigma), + c.get_shape(clen), + c.get_shape(seed)), + shape_with_layout=xla_client.Shape.tuple_shape( + ( + xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), + ) + ), + opaque=opaque, + ) + + +def _event_matvec_prob_normal_jvp( + primals, tangents, *, shape, transpose, outdim_parallel +): + events, w_mu, w_sigma, clen, seed = primals + events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents + r = event_mv_prob_normal_p.bind(events, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + assert type(w_mu_dot) is ad.Zero + assert type(w_sigma_dot) is ad.Zero + assert type(clen_dot) is ad.Zero + assert type(seed_dot) is ad.Zero + r_dot = mv_prob_normal_p.bind(events_dot, + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + return r, r_dot + + +def _event_matvec_prob_normal_transpose( + ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel +): + assert type(events) is ad.UndefinedPrimal + assert type(w_mu) is not ad.UndefinedPrimal + assert type(w_sigma) is not ad.UndefinedPrimal + assert type(clen) is not ad.UndefinedPrimal + assert type(seed) is not ad.UndefinedPrimal + + r = mv_prob_normal_p.bind(ct[0], + w_mu, + w_sigma, + clen, + seed, + shape=shape, + transpose=not transpose, + outdim_parallel=not outdim_parallel)[0] + return r, w_mu, w_sigma, clen, seed + + +event_mv_prob_normal_p = Primitive('event_mv_prob_normal') +event_mv_prob_normal_p.multiple_results = True +event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) +event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) +# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation +# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation +register_general_batching(event_mv_prob_normal_p) +ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp +ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose + + +### TAICHI ### + +def event_mv_prob_homo_taichi( + events: jax.Array, + weight: float, + conn_prob: float, + seed: Optional[int] = None, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, +) -> jax.Array: + r"""Perform the :math:`y=M@v` operation, + where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. + + This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations + on CPU and GPU devices. + + .. warning:: + + This API may change in the future. + + In this operation, :math:`M` is the random matrix with a connection probability + `conn_prob`, and at each connection the value is the same scalar `weight`. + + When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. + + .. note:: + + Note that the just-in-time generated :math:`M` (`transpose=False`) is + different from the generated :math:`M^T` (`transpose=True`). + + If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time + matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of + the speed compared with ``outdim_parallel=False``. + + Parameters + ---------- + events: Array, ndarray + The events. + weight: float + The value of the random matrix. + conn_prob: float + The connection probability. + shape: tuple of int + The matrix shape. + seed: int + The random number generation seed. + transpose: bool + Transpose the random matrix or not. + outdim_parallel: bool + Perform the parallel random generations along the out dimension or not. + It can be used to set the just-in-time generated :math:M^T: is the same + as the just-in-time generated :math:`M` when ``transpose=True``. + + Returns + ------- + out: Array, ndarray + The output of :math:`y = M @ v`. + """ events = as_jax(events) - if isinstance(weight, float): weight = as_jax(weight) - weight = jnp.atleast_1d(as_jax(weight)) + weight = as_jax(weight) + if jnp.ndim(weight) < 1: + weight = jnp.expand_dims(weight, axis=0) conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) if seed is None: diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 1c8ca6ef..6326929c 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- - +import gc import warnings from brainpy import errors @@ -11,6 +11,7 @@ _name2id = dict() _typed_names = {} +_fun2stack = dict() def check_name_uniqueness(name, obj): @@ -49,9 +50,6 @@ def clear_name_cache(ignore_warn=False): warnings.warn(f'All named models and their ids are cleared.', UserWarning) -_fun2stack = dict() - - def cache_stack(func, stack): _fun2stack[func] = stack @@ -59,6 +57,7 @@ def cache_stack(func, stack): def clear_stack_cache(): for k in tuple(_fun2stack.keys()): del _fun2stack[k] + gc.collect() def get_stack_cache(func): From 3826c548939516015ff138e37566052c5472ccba Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:19:00 +0800 Subject: [PATCH 04/21] add `brainpy.math.surrogate..Surrogate` (#638) * add `brainpy.math.Surrogate` * fix * Update _event_matvec.py * Update _event_matvec.py * Update _event_matvec.py --------- Co-authored-by: He Sichao <1310722434@qq.com> --- brainpy/_src/math/surrogate/_one_input_new.py | 27 +++++++++++++++++-- brainpy/math/surrogate.py | 3 ++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/surrogate/_one_input_new.py b/brainpy/_src/math/surrogate/_one_input_new.py index 64c7280d..bfffd88f 100644 --- a/brainpy/_src/math/surrogate/_one_input_new.py +++ b/brainpy/_src/math/surrogate/_one_input_new.py @@ -90,7 +90,30 @@ def _as_jax(x): class Surrogate(object): - """The base surrograte gradient function.""" + """The base surrograte gradient function. + + To customize a surrogate gradient function, you can inherit this class and + implement the `surrogate_fun` and `surrogate_grad` methods. + + Examples + -------- + + >>> import brainpy as bp + >>> import brainpy.math as bm + >>> import jax.numpy as jnp + + >>> class MySurrogate(bm.Surrogate): + ... def __init__(self, alpha=1.): + ... super().__init__() + ... self.alpha = alpha + ... + ... def surrogate_fun(self, x): + ... return jnp.sin(x) * self.alpha + ... + ... def surrogate_grad(self, x): + ... return jnp.cos(x) * self.alpha + + """ def __call__(self, x): x = _as_jax(x) @@ -123,7 +146,7 @@ def __init__(self, alpha: float = 4.): self.alpha = alpha def surrogate_fun(self, x): - return sci.special.expit(x) + return sci.special.expit(self.alpha * x) def surrogate_grad(self, x): sgax = sci.special.expit(x * self.alpha) diff --git a/brainpy/math/surrogate.py b/brainpy/math/surrogate.py index 0121bdde..bf789743 100644 --- a/brainpy/math/surrogate.py +++ b/brainpy/math/surrogate.py @@ -1,8 +1,9 @@ # -*- coding: utf-8 -*- - from brainpy._src.math.surrogate._one_input_new import ( + Surrogate, + Sigmoid, sigmoid as sigmoid, From 5112f25b4ee1594437b3cf2a40e45c8159574d44 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 1 Mar 2024 15:19:57 +0800 Subject: [PATCH 05/21] Enable brainpy object as pytree so that it can be applied with ``jax.jit`` etc. directly (#625) enable brainpy object as pytree --- brainpy/_src/math/defaults.py | 4 +++ brainpy/_src/math/environment.py | 25 ++++++++++++++++--- brainpy/_src/math/object_transform/base.py | 15 +++++++---- .../math/object_transform/tests/test_base.py | 19 ++++++++++++++ 4 files changed, 55 insertions(+), 8 deletions(-) diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 6ebe9dc2..9f3c5045 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -30,6 +30,9 @@ # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 +# register brainpy object as pytree +bp_object_as_pytree = False + if ti is not None: # '''Default integer data type in Taichi.''' ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 @@ -40,3 +43,4 @@ else: ti_int = None ti_float = None + diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 7827dfed..d49e70f5 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -168,6 +168,7 @@ def __init__( float_: type = None, int_: type = None, bool_: type = None, + bp_object_as_pytree: bool = None, ) -> None: super().__init__() @@ -203,6 +204,10 @@ def __init__( assert isinstance(complex_, type), '"complex_" must a type.' self.old_complex = get_complex() + if bp_object_as_pytree is not None: + assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' + self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + self.dt = dt self.mode = mode self.membrane_scaling = membrane_scaling @@ -211,6 +216,7 @@ def __init__( self.float_ = float_ self.int_ = int_ self.bool_ = bool_ + self.bp_object_as_pytree = bp_object_as_pytree def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) @@ -221,6 +227,7 @@ def __enter__(self) -> 'environment': if self.int_ is not None: set_int(self.int_) if self.complex_ is not None: set_complex(self.complex_) if self.bool_ is not None: set_bool(self.bool_) + if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -232,6 +239,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.float_ is not None: set_float(self.old_float) if self.complex_ is not None: set_complex(self.old_complex) if self.bool_ is not None: set_bool(self.old_bool) + if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree def clone(self): return self.__class__(dt=self.dt, @@ -241,7 +249,8 @@ def clone(self): bool_=self.bool_, complex_=self.complex_, float_=self.float_, - int_=self.int_) + int_=self.int_, + bp_object_as_pytree=self.bp_object_as_pytree) def __eq__(self, other): return id(self) == id(other) @@ -269,6 +278,7 @@ def __init__( bool_: type = None, batch_size: int = 1, membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -277,7 +287,8 @@ def __init__( int_=int_, bool_=bool_, membrane_scaling=membrane_scaling, - mode=modes.TrainingMode(batch_size)) + mode=modes.TrainingMode(batch_size), + bp_object_as_pytree=bp_object_as_pytree) class batching_environment(environment): @@ -303,6 +314,7 @@ def __init__( bool_: type = None, batch_size: int = 1, membrane_scaling: scales.Scaling = None, + bp_object_as_pytree: bool = None, ): super().__init__(dt=dt, x64=x64, @@ -311,7 +323,8 @@ def __init__( int_=int_, bool_=bool_, mode=modes.BatchingMode(batch_size), - membrane_scaling=membrane_scaling) + membrane_scaling=membrane_scaling, + bp_object_as_pytree=bp_object_as_pytree) def set( @@ -323,6 +336,7 @@ def set( float_: type = None, int_: type = None, bool_: type = None, + bp_object_as_pytree: bool = None, ): """Set the default computation environment. @@ -344,6 +358,8 @@ def set( The integer data type. bool_ The bool data type. + bp_object_as_pytree: bool + Whether to register brainpy object as pytree. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -377,6 +393,9 @@ def set( assert isinstance(complex_, type), '"complex_" must a type.' set_complex(complex_) + if bp_object_as_pytree is not None: + defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree + set_environment = set diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index aaf053ae..53346a7d 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -6,23 +6,24 @@ """ import numbers -import os import warnings from collections import namedtuple from typing import Any, Tuple, Callable, Sequence, Dict, Union, Optional import jax import numpy as np +from jax._src.tree_util import _registry +from jax.tree_util import register_pytree_node_class -from brainpy import errors +from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) from brainpy._src.math.object_transform.naming import (get_unique_name, check_name_uniqueness) from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) -from brainpy._src.math.modes import Mode from brainpy._src.math.sharding import BATCH_AXIS +from brainpy._src.math import defaults variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) @@ -89,6 +90,10 @@ class BrainPyObject(object): def __init__(self, name=None): super().__init__() + if defaults.bp_object_as_pytree: + if self.__class__ not in _registry: + register_pytree_node_class(self.__class__) + # check whether the object has a unique name. self._name = None self._name = self.unique_name(name=name) @@ -217,8 +222,8 @@ def tree_flatten(self): static_names = [] static_values = [] for k, v in self.__dict__.items(): - # if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)): - if isinstance(v, (BrainPyObject, Variable)): + if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)): + # if isinstance(v, (BrainPyObject, Variable)): dynamic_names.append(k) dynamic_values.append(v) else: diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index 2d640b3b..c6f8f90d 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -231,3 +231,22 @@ def f1(): self.assertTrue(obj.vs['b'] == 12.) self.assertTrue(bm.allclose(obj.vs['c'], bm.ones(10) * 11.)) + +class TestRegisterBPObjectAsPyTree(unittest.TestCase): + def test1(self): + bm.set(bp_object_as_pytree=True) + + hh = bp.dyn.HH(1) + hh.reset() + + tree = jax.tree_structure(hh) + leaves = jax.tree_leaves(hh) + + print(tree) + print(leaves) + print(jax.tree_unflatten(tree, leaves)) + print() + + + + From 26d09c5eeff197ae90befbc91fd049d8c647f9b2 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 2 Mar 2024 19:51:34 +0800 Subject: [PATCH 06/21] Fix CI (#640) * Update requirements-dev-raw.txt * update tests * recovery * upgrade tests * Update CI.yml * Revert "Update CI.yml" This reverts commit 343476ff9186814c757a45a3fd5784066dd0cdec. * update tests * update tests * update CI * update CI * update CI --------- Co-authored-by: He Sichao <1310722434@qq.com> --- .github/workflows/CI.yml | 147 +- brainpy/_src/connect/tests/test_all_time.py | 1272 +++++++++-------- brainpy/_src/dnn/tests/test_linear.py | 34 +- brainpy/_src/dnn/tests/test_mode.py | 2 - .../synapses/tests/test_abstract_synapses.py | 5 + brainpy/_src/math/event/__init__.py | 2 +- .../event/{_csr_matvec.py => csr_matvec.py} | 5 +- .../_src/math/event/tests/test_event_csrmv.py | 100 +- brainpy/_src/math/jitconn/__init__.py | 4 +- .../{_event_matvec.py => event_matvec.py} | 759 +--------- .../math/jitconn/{_matvec.py => matvec.py} | 0 .../math/jitconn/tests/test_event_matvec.py | 913 ++++++------ .../_src/math/jitconn/tests/test_matvec.py | 748 +++++----- .../_src/math/op_register/taichi_aot_based.py | 7 + .../op_register/tests/test_numba_based.py | 8 +- .../op_register/tests/test_taichi_based.py | 8 +- .../tests/test_taichi_clean_cache.py | 16 +- brainpy/_src/math/sparse/__init__.py | 8 +- .../math/sparse/{_bsr_mm.py => bsr_mm.py} | 0 .../math/sparse/{_bsr_mv.py => bsr_mv.py} | 2 +- .../math/sparse/{_coo_mv.py => coo_mv.py} | 0 .../math/sparse/{_csr_mv.py => csr_mv.py} | 3 +- .../math/sparse/{_jax_prim.py => jax_prim.py} | 0 brainpy/_src/math/sparse/tests/test_csrmv.py | 63 +- .../_src/math/sparse/{_utils.py => utils.py} | 0 requirements-dev-raw.txt | 12 - 26 files changed, 1548 insertions(+), 2570 deletions(-) rename brainpy/_src/math/event/{_csr_matvec.py => csr_matvec.py} (99%) rename brainpy/_src/math/jitconn/{_event_matvec.py => event_matvec.py} (56%) rename brainpy/_src/math/jitconn/{_matvec.py => matvec.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mm.py => bsr_mm.py} (100%) rename brainpy/_src/math/sparse/{_bsr_mv.py => bsr_mv.py} (99%) rename brainpy/_src/math/sparse/{_coo_mv.py => coo_mv.py} (100%) rename brainpy/_src/math/sparse/{_csr_mv.py => csr_mv.py} (99%) rename brainpy/_src/math/sparse/{_jax_prim.py => jax_prim.py} (100%) rename brainpy/_src/math/sparse/{_utils.py => utils.py} (100%) delete mode 100644 requirements-dev-raw.txt diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d29b07eb..7f46c959 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -16,12 +16,10 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows - -#on: -# push: -# branches: [ master ] -# pull_request: -# branches: [ master ] +# This is what will cancel the workflow +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: @@ -30,14 +28,16 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 + uses: styfle/cancel-workflow-action@0.12.1 with: - access_token: ${{ github.token }} + access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -53,15 +53,22 @@ jobs: cd brainpy pytest _src/ - test_linux_with_taichi_numba: - runs-on: ubuntu-latest + + test_macos: + runs-on: macos-latest strategy: fail-fast: false matrix: - python-version: [ "3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: @@ -69,109 +76,41 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi + if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - name: Test with pytest run: | cd brainpy pytest _src/ - test_macos: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] - - steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest -n auto --tb=short _src/ - - test_windows: + runs-on: windows-latest strategy: fail-fast: false matrix: - os: [ win-2019-16core ] - arch: [ AMD64 ] - python-version: ["3.9", "3.10", "3.11"] - runs-on: ${{ matrix.os }} - - steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install -r requirements-dev.txt - pip uninstall brainpy -y - python setup.py install - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - - test_macos_with_taichi_numba: - runs-on: macos-latest - strategy: - fail-fast: false - matrix: - python-version: ["3.9", "3.10", "3.11"] + python-version: [ "3.9", "3.10", "3.11" ] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest taichi numba - if [ -f requirements-dev-raw.txt ]; then pip install -r requirements-dev-raw.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Lint with flake8 - run: | - # stop the build if there are Python syntax errors or undefined names - flake8 brainpy/ --count --select=E9,F63,F7,F82 --show-source --statistics - # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 brainpy/ --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics - - name: Test with pytest - run: | - cd brainpy - pytest _src/ - + - name: Cancel Previous Runs + uses: styfle/cancel-workflow-action@0.12.1 + with: + access_token: ${{ github.token }} + - uses: actions/checkout@v4 + - name: Print concurrency group + run: echo '${{ github.workflow }}-${{ github.ref }}' + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements-dev.txt + pip uninstall brainpy -y + python setup.py install + - name: Test with pytest + run: | + cd brainpy + pytest _src/ diff --git a/brainpy/_src/connect/tests/test_all_time.py b/brainpy/_src/connect/tests/test_all_time.py index b634d6db..07422f65 100644 --- a/brainpy/_src/connect/tests/test_all_time.py +++ b/brainpy/_src/connect/tests/test_all_time.py @@ -1,18 +1,19 @@ import time from datetime import datetime -import brainpy as bp -import unittest import pytest +import brainpy as bp + +pytest.skip('skip.', allow_module_level=True) + try: - import pandas as pd + import pandas as pd - df = pd.DataFrame( - columns=['connector name', 'superclass', 'connect matrix size', 'build function', 'other parameter', - 'time(ms)']) + df = pd.DataFrame(columns=['connector name', 'superclass', 'connect matrix size', + 'build function', 'other parameter', 'time(ms)']) except (ImportError, ModuleNotFoundError): - print('No pandas installed, skip test.') + print('No pandas installed, skip test.') # size_same = [100, 500, 2500, 12500, 25000, 37500, 50000] # size_same = [100, 500, 2500, 12500] @@ -21,644 +22,645 @@ size_same = [100, 500, 2500] size_diff = [(10, 100), (100, 1000)] + def get_ms(value): - return round(value * 1000, 4) + return round(value * 1000, 4) def insert_row(connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used): - try: - df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + try: + df.loc[len(df)] = [connector_name, superclass, connect_matrix_size, build_function, other_parameter, time_used] + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') class OneEndConnector(unittest.TestCase): - def test_gaussian_prob(self): - print() - for size in size_same: - print('GaussianProb:', size) - conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'sigma=1/include_self=False', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['GaussianProb', - # 'OneEndConnector', - # f'{size}x{size}', - # 'build_coo', - # 'sigma=1/include_self=False', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GaussianProb', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'sigma=1/include_self=False', - time_used) - - def test_grid_four(self): - print() - for size in size_same: - print('GridFour:', size) - conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridFour', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_eight(self): - print() - for size in size_same: - print('GridEight:', size) - conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridEight', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False', - time_used) - - def test_grid_n(self): - print() - for size in size_same: - print('GridN:', size) - conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_mat', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_coo', - 'include_self=False/periodic_boundary=False/N=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('GridN', - 'OneEndConnector', - f'{size}x{size}', - 'build_csr', - 'include_self=False/periodic_boundary=False/N=2', - time_used) + def test_gaussian_prob(self): + print() + for size in size_same: + print('GaussianProb:', size) + conn = bp.connect.GaussianProb(sigma=1., include_self=False, seed=123)(pre_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'sigma=1/include_self=False', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['GaussianProb', + # 'OneEndConnector', + # f'{size}x{size}', + # 'build_coo', + # 'sigma=1/include_self=False', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GaussianProb', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'sigma=1/include_self=False', + time_used) + + def test_grid_four(self): + print() + for size in size_same: + print('GridFour:', size) + conn = bp.connect.GridFour(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridFour', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_eight(self): + print() + for size in size_same: + print('GridEight:', size) + conn = bp.connect.GridEight(include_self=False, periodic_boundary=False)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridEight', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False', + time_used) + + def test_grid_n(self): + print() + for size in size_same: + print('GridN:', size) + conn = bp.connect.GridN(include_self=False, periodic_boundary=False, N=2)(size, size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_mat', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_coo', + 'include_self=False/periodic_boundary=False/N=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('GridN', + 'OneEndConnector', + f'{size}x{size}', + 'build_csr', + 'include_self=False/periodic_boundary=False/N=2', + time_used) class TwoEndConnector(unittest.TestCase): - def test_fixed_prob(self): - print() - for size in size_same: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'prob=0.1', - time_used) - - for size in size_diff: - print('FixedProb:', size) - conn = bp.connect.FixedProb(prob=0.1, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'prob=0.1', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedProb', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'prob=0.1', - time_used) - - def test_fixed_pre_num(self): - print() - for size in size_same: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'pre_num=10', - time_used) - - for size in size_diff: - print('FixedPreNum:', size) - conn = bp.connect.FixedPreNum(num=0.4, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_fixed_post_num(self): - print() - for size in size_same: - print('FixedPostNum:', size) - conn = bp.connect.FixedPostNum(num=10, seed=123) - conn(pre_size=size, post_size=size) - - start = time.time() - mat = conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num=10', - time_used) - - for size in size_diff: - print('FixedPostNum:', size) - conn = bp.connect.FixedPreNum(num=10, seed=123) - conn(pre_size=size[0], post_size=size[1]) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_mat', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_coo', - 'pre_num=10', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('FixedPreNum', - 'TwoEndConnector', - f'{size[0]}x{size[1]}', - 'build_csr', - 'pre_num=10', - time_used) - - def test_prob_dist(self): - print() - for size in size_same: - print('ProbDist:', size) - conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'prob=0.5', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ProbDist', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', - time_used) - - def test_small_world(self): - print() - for size in size_same: - print('SmallWorld:', size) - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('SmallWorld', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'num_neighbor=2/prob=0.5/include_self=False', - time_used) - - def test_scale_free_ba(self): - print() - for size in size_same: - print('ScaleFreeBA:', size) - conn = bp.connect.ScaleFreeBA(m=2) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=2', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBA', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=2', - time_used) - - def test_scale_free_ba_dual(self): - print() - for size in size_same: - print('ScaleFreeBADual:', size) - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm1=2/m2=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('ScaleFreeBADual', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm1=2/m2=3/p=0.4', - time_used) - - def test_power_law(self): - print() - for size in size_same: - print('PowerLaw:', size) - conn = bp.connect.PowerLaw(m=3, p=0.4) - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - 'm=3/p=0.4', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('PowerLaw', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - 'm=3/p=0.4', - time_used) - - def test_one2one(self): - print() - for size in size_same: - print('One2One:', size) - conn = bp.connect.One2One() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - start = time.time() - conn.require(bp.connect.COO) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_coo', - '', - time_used) - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('One2One', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) - - def test_all2all(self): - print() - for size in size_same: - print('All2All:', size) - conn = bp.connect.All2All() - conn(pre_size=size, post_size=size) - - start = time.time() - conn.require(bp.connect.CONN_MAT) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_mat', - '', - time_used) - - # start = time.time() - # conn.require(bp.connect.COO) - # time_used = get_ms(time.time() - start) - # df.loc[len(df)] = ['All2All', - # 'TwoEndConnector', - # f'{size}x{size}', - # 'build_coo', - # '', - # time_used] - - start = time.time() - conn.require(bp.connect.CSR) - time_used = get_ms(time.time() - start) - insert_row('All2All', - 'TwoEndConnector', - f'{size}x{size}', - 'build_csr', - '', - time_used) + def test_fixed_prob(self): + print() + for size in size_same: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'prob=0.1', + time_used) + + for size in size_diff: + print('FixedProb:', size) + conn = bp.connect.FixedProb(prob=0.1, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'prob=0.1', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedProb', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'prob=0.1', + time_used) + + def test_fixed_pre_num(self): + print() + for size in size_same: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'pre_num=10', + time_used) + + for size in size_diff: + print('FixedPreNum:', size) + conn = bp.connect.FixedPreNum(num=0.4, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_fixed_post_num(self): + print() + for size in size_same: + print('FixedPostNum:', size) + conn = bp.connect.FixedPostNum(num=10, seed=123) + conn(pre_size=size, post_size=size) + + start = time.time() + mat = conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num=10', + time_used) + + for size in size_diff: + print('FixedPostNum:', size) + conn = bp.connect.FixedPreNum(num=10, seed=123) + conn(pre_size=size[0], post_size=size[1]) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_mat', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_coo', + 'pre_num=10', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('FixedPreNum', + 'TwoEndConnector', + f'{size[0]}x{size[1]}', + 'build_csr', + 'pre_num=10', + time_used) + + def test_prob_dist(self): + print() + for size in size_same: + print('ProbDist:', size) + conn = bp.connect.ProbDist(dist=1, prob=0.5, pre_ratio=0.3, seed=1234, include_self=True) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'prob=0.5', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ProbDist', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'dist=1|prob=0.5|pre_ratio=0.3|include_self=True', + time_used) + + def test_small_world(self): + print() + for size in size_same: + print('SmallWorld:', size) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('SmallWorld', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'num_neighbor=2/prob=0.5/include_self=False', + time_used) + + def test_scale_free_ba(self): + print() + for size in size_same: + print('ScaleFreeBA:', size) + conn = bp.connect.ScaleFreeBA(m=2) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=2', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBA', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=2', + time_used) + + def test_scale_free_ba_dual(self): + print() + for size in size_same: + print('ScaleFreeBADual:', size) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm1=2/m2=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('ScaleFreeBADual', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm1=2/m2=3/p=0.4', + time_used) + + def test_power_law(self): + print() + for size in size_same: + print('PowerLaw:', size) + conn = bp.connect.PowerLaw(m=3, p=0.4) + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + 'm=3/p=0.4', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('PowerLaw', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + 'm=3/p=0.4', + time_used) + + def test_one2one(self): + print() + for size in size_same: + print('One2One:', size) + conn = bp.connect.One2One() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + start = time.time() + conn.require(bp.connect.COO) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_coo', + '', + time_used) + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('One2One', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) + + def test_all2all(self): + print() + for size in size_same: + print('All2All:', size) + conn = bp.connect.All2All() + conn(pre_size=size, post_size=size) + + start = time.time() + conn.require(bp.connect.CONN_MAT) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_mat', + '', + time_used) + + # start = time.time() + # conn.require(bp.connect.COO) + # time_used = get_ms(time.time() - start) + # df.loc[len(df)] = ['All2All', + # 'TwoEndConnector', + # f'{size}x{size}', + # 'build_coo', + # '', + # time_used] + + start = time.time() + conn.require(bp.connect.CSR) + time_used = get_ms(time.time() - start) + insert_row('All2All', + 'TwoEndConnector', + f'{size}x{size}', + 'build_csr', + '', + time_used) class TestSave(unittest.TestCase): - def test_save(self): - try: - df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', - index=False) - except (NameError, UnboundLocalError): - print('No pandas installed, skip test.') + def test_save(self): + try: + df.to_csv('connector_time_' + datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '.csv', + index=False) + except (NameError, UnboundLocalError): + print('No pandas installed, skip test.') diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 422f161f..6cc44538 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -20,7 +20,7 @@ def __init__(self, *args, **kwargs): size=[(10,), (20, 10), (5, 8, 10)], - num_out=[20, 10, 5] + num_out=[20,] ) def test_Dense1(self, size, num_out): bm.random.seed() @@ -131,8 +131,8 @@ def test_EventCSRLinear(self, conn): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPHomoLinear(self, prob, weight, shape): @@ -144,9 +144,9 @@ def test_JitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01, ], + w_high=[0.01, ], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -158,9 +158,9 @@ def test_JitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): @@ -172,8 +172,8 @@ def test_JitFPNormalLinear(self, prob, w_mu, w_sigma, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - weight=[0.01, 0.01], + prob=[0.1], + weight=[0.01,], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPHomoLinear(self, prob, weight, shape): @@ -187,9 +187,9 @@ def test_EventJitFPHomoLinear(self, prob, weight, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.05, 0.5], - w_low=[-0.01, -0.01], - w_high=[0.01, 0.01], + prob=[0.1], + w_low=[-0.01], + w_high=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): @@ -203,9 +203,9 @@ def test_EventJitFPUniformLinear(self, prob, w_low, w_high, shape): bm.clear_buffer_memory() @parameterized.product( - prob=[0.01, 0.1, 0.5], - w_mu=[-0.01, -0.01], - w_sigma=[0.01, 0.01], + prob=[0.1], + w_mu=[-0.01], + w_sigma=[0.01], shape=[(), (10,), (10, 20), (10, 20, 25)] ) def test_EventJitFPNormalLinear(self, prob, w_mu, w_sigma, shape): diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index f0c67da1..10e9eeda 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -4,7 +4,6 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -63,7 +62,6 @@ def test_Conv2_NonBatching(self): mode=bm.NonBatchingMode()) output = layer(input) bm.clear_buffer_memory() - bm.clear_buffer_memory() @parameterized.product( mode=[bm.TrainingMode(), diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index 6db945ff..d068f207 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -3,9 +3,14 @@ from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm from brainpy._src.dynold.synapses import abstract_models +from brainpy._src.dependency_check import import_taichi + +if import_taichi(error_if_not_found=False) is None: + pytest.skip('no taichi', allow_module_level=True) class Test_Abstract_Synapse(parameterized.TestCase): diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index bdd3102a..9ebad3e9 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,2 @@ -from ._csr_matvec import * +from .csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py similarity index 99% rename from brainpy/_src/math/event/_csr_matvec.py rename to brainpy/_src/math/event/csr_matvec.py index 6b7f7da0..9890838e 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -20,8 +20,8 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError __all__ = [ @@ -30,6 +30,7 @@ ti = import_taichi(error_if_not_found=False) + def csrmv( data: Union[float, jax.Array], indices: jax.Array, diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 67e09d0a..6c0a2ed4 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -9,13 +9,11 @@ import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - seed = 1234 @@ -26,7 +24,6 @@ def func(*args, **kwargs): return func -taichi_csr_matvec = bm.event.csrmv class Test_event_csr_matvec_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): @@ -37,22 +34,22 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') events = rng.random(shape[0] if transpose else shape[1]) < 0.1 heter_data = bm.ones(indices.shape) * homo_data dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (events @ dense) if transpose else (dense @ events) - r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -60,23 +57,22 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_vmap(self, shape, transpose, homo_data): print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -84,14 +80,14 @@ def test_homo_vmap(self, shape, transpose, homo_data): # vmap 'events' f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax([homo_data] * 10) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 @@ -102,16 +98,15 @@ def test_homo_vmap(self, shape, transpose, homo_data): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)], - homo_data=[-1., 0., 1.], + shape=[(100, 200), (10, 1000)], + homo_data=[1.], ) def test_homo_grad(self, shape, transpose, homo_data): print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}') - rng = bm.random.RandomState(seed=seed) + homo_data = bm.asarray([homo_data]) + + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -119,31 +114,26 @@ def test_homo_grad(self, shape, transpose, homo_data): dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape) # grad 'data' - r1 = jax.grad(sum_op(bm.sparse.csrmv))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r1 = jax.grad(sum_op(bm.sparse.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmv))(homo_data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' - r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) + r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(homo_data, indices, indptr, events.astype(float), shape=shape, + transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000), ] + shape=[(100, 200), (10, 1000), ] ) def test_heter(self, shape, transpose): print(f'test_heter: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -151,9 +141,9 @@ def test_heter(self, shape, transpose): heter_data = bm.as_jax(rng.random(indices.shape)) r1 = bm.sparse.csrmv(heter_data, indices, indptr, events, + shape=shape, transpose=transpose) + r2 = bm.event.csrmv(heter_data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = taichi_csr_matvec(heter_data, indices, indptr, events, - shape=shape, transpose=transpose) assert (bm.allclose(r1, r2)) @@ -161,24 +151,21 @@ def test_heter(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_vmap(self, shape, transpose): print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) # vmap 'data' events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, + f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=events, shape=shape, transpose=transpose)) - f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events, + f2 = jax.vmap(partial(bm.event.csrmv, indices=indices, indptr=indptr, events=events, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data))) @@ -187,16 +174,16 @@ def test_heter_vmap(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) - f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr, + f4 = jax.vmap(partial(bm.event.csrmv, data, indices, indptr, shape=shape, transpose=transpose)) vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1 self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data))) # vmap 'data' and 'events' f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, + shape=shape, transpose=transpose)) + f6 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose)) - f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, - shape=shape, transpose=transpose)) vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0]))) vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2 self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2), @@ -206,15 +193,12 @@ def test_heter_vmap(self, shape, transpose): @parameterized.product( transpose=[True, False], - shape=[(100, 200), - (200, 200), - (200, 100), - (10, 1000)] + shape=[(100, 200), (10, 1000)] ) def test_heter_grad(self, shape, transpose): print(f'test_heter_grad: shape = {shape}, transpose = {transpose}') - rng = bm.random.RandomState(seed=seed) + rng = bm.random.RandomState(seed) indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) @@ -226,20 +210,20 @@ def test_heter_grad(self, shape, transpose): data = bm.as_jax(rng.random(indices.shape)) r1 = jax.grad(sum_op(bm.sparse.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) - r2 = jax.grad(sum_op(taichi_csr_matvec))( + r2 = jax.grad(sum_op(bm.event.csrmv))( data, indices, indptr, events, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad 'events' r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( + r4 = jax.grad(sum_op(bm.event.csrmv), argnums=3)( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( + r6 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))( data, indices, indptr, events.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) diff --git a/brainpy/_src/math/jitconn/__init__.py b/brainpy/_src/math/jitconn/__init__.py index 6f7cddf6..bb6a3c1f 100644 --- a/brainpy/_src/math/jitconn/__init__.py +++ b/brainpy/_src/math/jitconn/__init__.py @@ -1,2 +1,2 @@ -from ._matvec import * -from ._event_matvec import * +from .matvec import * +from .event_matvec import * diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py similarity index 56% rename from brainpy/_src/math/jitconn/_event_matvec.py rename to brainpy/_src/math/jitconn/event_matvec.py index ac62bbfa..27998038 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -8,17 +8,17 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo, - mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) +from brainpy._src.math.jitconn.matvec import (mv_prob_homo, + mv_prob_uniform, + mv_prob_normal, + _general_checking, + raw_mv_prob_homo, + raw_mv_prob_uniform, + raw_mv_prob_normal, + _mv_prob_homo_transpose, + _mv_prob_uniform_transpose, + _mv_prob_normal_transpose, + _reverse) from brainpy._src.math.ndarray import _get_dtype from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError @@ -45,743 +45,6 @@ def event_mv_prob_homo( if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - -event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(jnp.asarray(weight)) - conn_prob = jnp.atleast_1d(jnp.asarray(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) -ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### - -def event_mv_prob_homo_taichi( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - events: Array, ndarray - The events. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ events = as_jax(events) weight = as_jax(weight) if jnp.ndim(weight) < 1: diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/matvec.py similarity index 100% rename from brainpy/_src/math/jitconn/_matvec.py rename to brainpy/_src/math/jitconn/matvec.py diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index d8e08654..6fb8d02e 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,515 +11,419 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) - shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.event_mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.event_mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.event_mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_event_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - homo_data=[-1., ], - bool_event=[True, False], - seed=[1234], - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): - print(f'_test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_homo(events, - homo_data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') - # indices = bm.as_jax(indices) - # indptr = bm.as_jax(indptr) - # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, - # shape=shape, transpose=transpose) - # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - x64=[True, False], - outdim_parallel=[True, False], - shape=shapes, - prob=[0.01, 0.1, 0.5], - bool_event=[True, False], - seed=[1234], + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_event_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[-1.], + bool_event=[True, False], + seed=[1234], + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, bool_event=True, seed=1234, x64=False): + print(f'_test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_homo(events, + homo_data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + # indices, indptr = bp.conn.FixedProb(prob)(*shape).require('pre2post') + # indices = bm.as_jax(indices) + # indptr = bm.as_jax(indptr) + # r3 = event_ops.event_csr_matvec(homo_data, indices, indptr, events, + # shape=shape, transpose=transpose) + # print('Homo difference: ', bm.abs(r1 - r3).sum() / r1.size) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + seed=[1234], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): + print(f'_test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'bool_event = {bool_event}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + transpose=transpose, outdim_parallel=outdim_parallel + )[0] ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, bool_event=True, seed=1234, x64=False): - print(f'_test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - transpose=transpose, outdim_parallel=outdim_parallel - )[0] - ) - r1 = f1(events, weights) - r1 = jax.block_until_ready(r1) - r2 = f1(events, weights) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.5] + r1 = f1(events, weights) + r1 = jax.block_until_ready(r1) + r2 = f1(events, weights) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.5 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.event_mv_prob_homo( + event, data, conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), + argnums=0 ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.5 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose)[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - - r3 = f1(events, 3.) - r3 = jax.block_until_ready(r3) - - self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'bool_event = {bool_event}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - bool_event=bool_event, - seed=1234, - x64=x64 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, 0.4] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] - for bool_event in [True, False] + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + + r3 = f1(events, 3.) + r3 = jax.block_until_ready(r3) + + self.assertTrue(jnp.allclose(r1 * 3., r3, atol=1e-6)) + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-1.], + w_high=[1.], + bool_event=[True, False] + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_uniform(events, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + bool_event=[True, False], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap( + lambda e: bm.jitconn.event_mv_prob_uniform(e, + w_low=0., + w_high=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_uniform(events, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel, prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_uniform_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'bool_event={bool_event}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap( - lambda e: taichi_mv_prob_uniform(e, - w_low=0., - w_high=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - ) - - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - testcase_name=f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_high: taichi_mv_prob_uniform( - e, - w_low=0., - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - # print(r1) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu={w_mu}, ' - f'w_sigma={w_sigma}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1, ] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] - for bool_event in [True, False] - ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal: shape = {shape}, ' - f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' - f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - r1 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r1 = jax.block_until_ready(r1) - - r2 = taichi_mv_prob_normal(events, - w_mu=w_mu, - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - bool_event=bool_event, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_vmap: ' - f'shape={shape}, ' - f'transpose={transpose}, ' - f'outdim_parallel={outdim_parallel}, ' - f'prob={prob}, ' - f'bool_event={bool_event}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for bool_event in [True, False] + + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_high: bm.jitconn.event_mv_prob_uniform( + e, + w_low=0., + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, - bool_event=True, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 - events = bm.as_jax(events) - if not bool_event: - events = events.astype(float) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, - w_mu=0., - w_sigma=1., - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose)) - r1 = f1(events) - r1 = jax.block_until_ready(r1) - r2 = f1(events) - r2 = jax.block_until_ready(r2) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - x64=x64, - seed=1234, - testcase_name=f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + # print(r1) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1, ], + w_mu=[0.], + w_sigma=[0.1], + bool_event=[True, False], + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal: shape = {shape}, ' + f'transpose = {transpose}, outdim_parallel = {outdim_parallel}, prob={prob}, ' + f'w_mu = {w_mu}, w_sigma = {w_sigma}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + r1 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r1 = jax.block_until_ready(r1) + + r2 = bm.jitconn.event_mv_prob_normal(events, + w_mu=w_mu, + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1], + bool_event = [True, False], + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, + bool_event=True, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random((10, shape[0] if transpose else shape[1])) < 0.1 + events = bm.as_jax(events) + if not bool_event: + events = events.astype(float) + + f1 = jax.vmap(lambda e: bm.jitconn.event_mv_prob_normal(e, + w_mu=0., + w_sigma=1., + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose)) + r1 = f1(events) + r1 = jax.block_until_ready(r1) + r2 = f1(events) + r2 = jax.block_until_ready(r2) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose = [True, False], + x64 = [True, False], + outdim_parallel = [True, False], + shape = shapes, + prob = [0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, x64={x64}') + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = rng.random(shape[0] if transpose else shape[1]) < 0.1 + events = bm.as_jax(events) + events = events.astype(float) + + f1 = jax.jit( + jax.grad( + lambda e, w_sigma: bm.jitconn.event_mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose).sum() + ) ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}') - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = rng.random(shape[0] if transpose else shape[1]) < 0.1 - events = bm.as_jax(events) - events = events.astype(float) - - f1 = jax.jit( - jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose).sum() - ) - ) - r1 = f1(events, 1.) - r1 = jax.block_until_ready(r1) - r2 = f1(events, 2.) - r2 = jax.block_until_ready(r2) - self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r1 = jax.block_until_ready(r1) + r2 = f1(events, 2.) + r2 = jax.block_until_ready(r2) + self.assertTrue(bm.allclose(r1 * 2, r2, atol=1e-6)) + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 8a0ae444..67c18124 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -1,10 +1,9 @@ # -*- coding: utf-8 -*- -from functools import partial import jax import jax.numpy as jnp -from absl.testing import parameterized import pytest +from absl.testing import parameterized import brainpy.math as bm from brainpy._src.dependency_check import import_taichi @@ -12,55 +11,38 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] -shapes = [(100, 200), (2, 1000), (1000, 2)] - -taichi_mv_prob_homo = bm.jitconn.mv_prob_homo -taichi_mv_prob_uniform = bm.jitconn.mv_prob_uniform -taichi_mv_prob_normal = bm.jitconn.mv_prob_normal +shapes = [(100, 200), (1000, 10)] class Test_matvec_prob_conn(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) - bm.set_platform(platform) - print() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}, ' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - homo_data=homo_data, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for homo_data in [-1., 1.] - ) - def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): - print(f'test_homo: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'homo_data = {homo_data}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_homo(vector, + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_matvec_prob_conn, self).__init__(*args, **kwargs) + bm.set_platform(platform) + print() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + homo_data=[1.] + ) + def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=1234, x64=False): + print(f'test_homo: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'homo_data = {homo_data}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, @@ -68,152 +50,118 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=123 outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_homo_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - weights = bm.as_jax(rng.random(10)) - - f1 = jax.vmap( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, shape=shape, seed=seed, - outdim_parallel=outdim_parallel, transpose=transpose - )[0] - ) - r1 = f1(events, weights) - r2 = f1(events, weights) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_homo_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_homo_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + weights = bm.as_jax(rng.random(10)) + + f1 = jax.vmap( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, shape=shape, seed=seed, + outdim_parallel=outdim_parallel, transpose=transpose + )[0] ) - def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_homo_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 - events = events.astype(float) - - f1 = jax.grad( - lambda event, data: taichi_mv_prob_homo( - event, data, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum(), - argnums=0 - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - - self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}' - f'x64 = {x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_low=w_low, - w_high=w_high, - x64=x64, - seed=1234) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_low, w_high in [(-1., 0.), (0., 1.), (-1., 1.)] + r1 = f1(events, weights) + r2 = f1(events, weights) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_homo_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.5 + events = events.astype(float) + + f1 = jax.grad( + lambda event, data: bm.jitconn.mv_prob_homo( + event, data, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum(), + argnums=0 ) - def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): - print(f'test_uniform: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_low = {w_low}, ' - f'w_high = {w_high}, ' - f'x64 = {x64}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_uniform(events, + r1 = f1(events, 1.) + r2 = f1(events, 2.) + + self.assertTrue(jnp.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_low=[-0.1], + w_high=[1.0], + ) + def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, seed=1234, x64=False): + print(f'test_uniform: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_low = {w_low}, ' + f'w_high = {w_high}, ' + f'x64 = {x64}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -222,7 +170,7 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_uniform(events, + r2 = bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=prob, @@ -230,45 +178,35 @@ def test_uniform(self, shape, transpose, outdim_parallel, prob, w_low, w_high, s seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_uniform_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'test_uniform_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_uniform(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'test_uniform_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_uniform(e, w_low=0., w_high=1., conn_prob=prob, @@ -277,107 +215,81 @@ def test_uniform_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=(f'test_uniform_grad, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64) - for x64 in [True, False] - for transpose in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_uniform_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - f1 = jax.grad( - lambda e, w_low, w_high: taichi_mv_prob_uniform( - e, - w_low=w_low, - w_high=w_high, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - - r1 = f1(events, 0., 1.) - r2 = f1(events, 0., 2.) - - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict( - testcase_name=(f'test_normal, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma},' - f'x64={x64}'), - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - w_mu=w_mu, - w_sigma=w_sigma, - seed=1234 - ) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - for w_mu, w_sigma in [(-1., 1.), (0., 0.1), (0., 0.5)] + r1 = f1(events) + r2 = f1(events) + self.assertTrue(jnp.allclose(r1, r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + x64=[True, False], + transpose=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + ) + def test_uniform_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_uniform_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + f1 = jax.grad( + lambda e, w_low, w_high: bm.jitconn.mv_prob_uniform( + e, + w_low=w_low, + w_high=w_high, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): - print(f'_test_normal: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'w_mu = {w_mu}, ' - f'w_sigma = {w_sigma}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - r1 = taichi_mv_prob_normal(events, + + r1 = f1(events, 0., 1.) + r2 = f1(events, 0., 2.) + + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1], + w_mu=[0.], + w_sigma=[0.2] + ) + def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, seed=1234, x64=False): + print(f'_test_normal: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}, ' + f'w_mu = {w_mu}, ' + f'w_sigma = {w_sigma}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + r1 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -386,7 +298,7 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se outdim_parallel=outdim_parallel, transpose=transpose) - r2 = taichi_mv_prob_normal(events, + r2 = bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=prob, @@ -394,46 +306,36 @@ def test_normal(self, shape, transpose, outdim_parallel, prob, w_mu, w_sigma, se seed=seed, outdim_parallel=outdim_parallel, transpose=transpose) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(testcase_name=f'test_normal_vmap, shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}', - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234) - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] - ) - def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_vmap: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) - - f1 = jax.vmap(lambda e: taichi_mv_prob_normal(e, + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_vmap: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + + rng = bm.random.RandomState() + events = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) + + f1 = jax.vmap(lambda e: bm.jitconn.mv_prob_normal(e, w_mu=0., w_sigma=1., conn_prob=prob, @@ -441,66 +343,54 @@ def test_normal_vmap(self, shape, transpose, outdim_parallel, prob, seed=1234, x seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - r1 = f1(events) - r2 = f1(events) - c = jnp.allclose(r1, r2, atol=1e-6) - if not c: - print(r1, r2) - print(r1 - r2) - self.assertTrue(c) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() - - @parameterized.named_parameters( - dict(shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - prob=prob, - seed=1234, - x64=x64, - testcase_name=f'test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}, ' - f'x64={x64}') - for transpose in [True, False] - for x64 in [True, False] - for outdim_parallel in [True, False] - for shape in shapes - for prob in [0.01, 0.1] + r1 = f1(events) + r2 = f1(events) + c = jnp.allclose(r1, r2, atol=1e-6) + if not c: + print(r1, r2) + print(r1 - r2) + self.assertTrue(c) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + x64=[True, False], + outdim_parallel=[True, False], + shape=shapes, + prob=[0.1] + ) + def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): + print(f'_test_normal_grad: ' + f'shape = {shape}, ' + f'transpose = {transpose}, ' + f'outdim_parallel = {outdim_parallel}, ' + f'prob={prob}') + + if x64: + bm.enable_x64() + rng = bm.random.RandomState() + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + events = events.astype(float) + + f1 = jax.grad( + lambda e, w_sigma: bm.jitconn.mv_prob_normal( + e, + w_mu=0., + w_sigma=w_sigma, + conn_prob=prob, + shape=shape, + seed=seed, + outdim_parallel=outdim_parallel, + transpose=transpose + )[0].sum() ) - def test_normal_grad(self, shape, transpose, outdim_parallel, prob, seed=1234, x64=False): - print(f'_test_normal_grad: ' - f'shape = {shape}, ' - f'transpose = {transpose}, ' - f'outdim_parallel = {outdim_parallel}, ' - f'prob={prob}') - - if x64: - bm.enable_x64() - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - events = events.astype(float) - - f1 = jax.grad( - lambda e, w_sigma: taichi_mv_prob_normal( - e, - w_mu=0., - w_sigma=w_sigma, - conn_prob=prob, - shape=shape, - seed=seed, - outdim_parallel=outdim_parallel, - transpose=transpose - )[0].sum() - ) - r1 = f1(events, 1.) - r2 = f1(events, 2.) - self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) - - if x64: - bm.disable_x64() - bm.clear_buffer_memory() + r1 = f1(events, 1.) + r2 = f1(events, 2.) + self.assertTrue(bm.allclose(r1 * 2., r2, atol=1e-6)) + + if x64: + bm.disable_x64() + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 7fac4452..f9328906 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -16,6 +16,7 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call +from brainpy.errors import PackageMissingError from brainpy._src.dependency_check import (import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops) @@ -485,10 +486,16 @@ def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel) mlir.register_lowering(primitive, rule, platform='cpu') def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel): + if import_taichi(error_if_not_found=False) is None: + raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') + rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel) mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py index dc093f62..28b80d0f 100644 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ b/brainpy/_src/math/op_register/tests/test_numba_based.py @@ -1,14 +1,16 @@ -import pytest import jax.core -import brainpy.math as bm +import pytest +import brainpy.math as bm from brainpy._src.dependency_check import import_numba + numba = import_numba(error_if_not_found=False) if numba is None: pytest.skip('no numba', allow_module_level=True) bm.set_platform('cpu') + @numba.njit(fastmath=True) def numba_event_csrmv(weight, indices, vector, outs): outs.fill(0) @@ -33,5 +35,3 @@ def test_event_ELL(): call(1000) call(100) bm.clear_buffer_memory() - - diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 4db38fbc..199dce98 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -1,10 +1,10 @@ -import pytest import jax import jax.numpy as jnp +import pytest import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi + ti = import_taichi(error_if_not_found=False) if ti is None: pytest.skip('no taichi', allow_module_level=True) @@ -35,6 +35,7 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + @ti.kernel def event_ell_gpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), @@ -47,12 +48,13 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 51c964b2..5b27b2fd 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -1,16 +1,15 @@ -import brainpy.math as bm import jax import jax.numpy as jnp -import platform -import pytest + +import brainpy.math as bm +import taichi as ti from brainpy._src.dependency_check import import_taichi ti = import_taichi(error_if_not_found=False) if ti is None: + import pytest pytest.skip('no taichi', allow_module_level=True) -if not platform.platform().startswith('Windows'): - pytest.skip(allow_module_level=True) @ti.func def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: @@ -21,6 +20,7 @@ def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): out[index] += weight_val + @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), @@ -34,11 +34,13 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) + prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + def test_taichi_clean_cache(): s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) + indices = bm.random.randint(0, s, (s, 100)) vector = bm.random.rand(s) < 0.1 weight = bm.array([1.0]) @@ -55,4 +57,4 @@ def test_taichi_clean_cache(): print('kernels: ', bm.check_kernels_count()) -# test_taichi_clean_cache() \ No newline at end of file +# test_taichi_clean_cache() diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d5353324..14256cbc 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ # from ._coo_mv import * # from ._bsr_mv import * -from ._csr_mv import * -from ._utils import * -from ._bsr_mm import * -from ._jax_prim import * +from .csr_mv import * +from .utils import * +from .bsr_mm import * +from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py similarity index 100% rename from brainpy/_src/math/sparse/_bsr_mm.py rename to brainpy/_src/math/sparse/bsr_mm.py diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_bsr_mv.py rename to brainpy/_src/math/sparse/bsr_mv.py index a35895bc..7dc0b683 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/bsr_mv.py @@ -11,7 +11,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy._src.dependency_check import import_brainpylib_gpu_ops from brainpy.errors import GPUOperatorNotFound diff --git a/brainpy/_src/math/sparse/_coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py similarity index 100% rename from brainpy/_src/math/sparse/_coo_mv.py rename to brainpy/_src/math/sparse/coo_mv.py diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py similarity index 99% rename from brainpy/_src/math/sparse/_csr_mv.py rename to brainpy/_src/math/sparse/csr_mv.py index 42969f43..6eaf6b79 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -13,7 +13,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) @@ -108,6 +108,7 @@ def raw_csrmv_taichi( ): if ti is None: raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + out_shape = shape[1] if transpose else shape[0] if data.shape[0] != 1: if bm.get_platform() == 'gpu': diff --git a/brainpy/_src/math/sparse/_jax_prim.py b/brainpy/_src/math/sparse/jax_prim.py similarity index 100% rename from brainpy/_src/math/sparse/_jax_prim.py rename to brainpy/_src/math/sparse/jax_prim.py diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index ec448e65..40bcbb70 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -3,12 +3,11 @@ from functools import partial import jax +import pytest from absl.testing import parameterized -import pytest import brainpy as bp import brainpy.math as bm - from brainpy._src.dependency_check import import_taichi if import_taichi(error_if_not_found=False) is None: @@ -25,7 +24,6 @@ def func(*args, **kwargs): return func - def compare_with_nan_tolerance(a, b, tol=1e-8): """ Compare two arrays with tolerance for NaN values. @@ -56,8 +54,6 @@ def compare_with_nan_tolerance(a, b, tol=1e-8): return bm.allclose(a_non_nan, b_non_nan, atol=tol) -taichi_csr_matvec = bm.sparse.csrmv - class Test_csrmv_taichi(parameterized.TestCase): def __init__(self, *args, platform='cpu', **kwargs): super(Test_csrmv_taichi, self).__init__(*args, **kwargs) @@ -67,8 +63,8 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo(self, transpose, shape, homo_data): print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -87,15 +83,15 @@ def test_homo(self, transpose, shape, homo_data): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (100, 1000), (2, 2000)], - v=[-1., 0., 1.] + shape=[(200, 200), (100, 1000)], + v=[1.] ) def test_homo_vmap(self, transpose, shape, v): print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}') @@ -113,7 +109,7 @@ def test_homo_vmap(self, transpose, shape, v): dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(homo_data) @@ -123,8 +119,8 @@ def test_homo_vmap(self, transpose, shape, v): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)], - homo_data=[-1., 0., 1.] + shape=[(200, 200), (10, 1000)], + homo_data=[1.] ) def test_homo_grad(self, transpose, shape, homo_data): print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}') @@ -148,8 +144,7 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ vector).sum()), argnums=0) r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(taichi_csr_matvec))( - homo_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = jax.grad(sum_op(bm.sparse.csrmv))(bm.asarray([homo_data]), indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -158,8 +153,8 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_data = dense * homo_data dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum())) r3 = dense_f2(vector) - r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r4 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -168,8 +163,8 @@ def test_homo_grad(self, transpose, shape, homo_data): ((dense * a) @ v).sum()), argnums=(0, 1)) r5 = dense_f3(homo_data, vector) - r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))( - homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))( + bm.asarray([homo_data]), indices, indptr, vector.astype(float), shape=shape, transpose=transpose) self.assertTrue(bm.allclose(r5[0], r6[0])) self.assertTrue(bm.allclose(r5[1], r6[1])) @@ -177,7 +172,7 @@ def test_homo_grad(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (2, 2000)], + shape=[(200, 200), (2, 2000)], ) def test_heter(self, transpose, shape): print(f'test_homo: transpose = {transpose} shape = {shape}') @@ -196,7 +191,7 @@ def test_heter(self, transpose, shape): dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) r1 = (vector @ dense) if transpose else (dense @ vector) - r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) + r2 = bm.sparse.csrmv(heter_data, indices, indptr, vector, shape=shape, transpose=transpose) self.assertTrue(compare_with_nan_tolerance(r1, r2)) @@ -204,7 +199,7 @@ def test_heter(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_vmap(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -222,7 +217,7 @@ def test_heter_vmap(self, transpose, shape): shape=shape))(heter_data) f1 = lambda a: (a.T @ vector) if transpose else (a @ vector) - f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector, + f2 = partial(bm.sparse.csrmv, indices=indices, indptr=indptr, vector=vector, shape=shape, transpose=transpose) r1 = jax.vmap(f1)(dense_data) r2 = jax.vmap(f2)(heter_data) @@ -230,7 +225,7 @@ def test_heter_vmap(self, transpose, shape): @parameterized.product( transpose=[True, False], - shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)] + shape=[(200, 200), (2, 2000)] ) def test_heter_grad(self, transpose, shape): rng = bm.random.RandomState(seed=seed) @@ -246,11 +241,10 @@ def test_heter_grad(self, transpose, shape): vector = bm.as_jax(vector) # grad 'data' - dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), - argnums=0) - csr_f1 = jax.grad(lambda a: taichi_csr_matvec(a, indices, indptr, vector, - shape=shape, - transpose=transpose).sum(), + dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()), argnums=0) + csr_f1 = jax.grad(lambda a: bm.sparse.csrmv(a, indices, indptr, vector, + shape=shape, + transpose=transpose).sum(), argnums=0) r1 = csr_f1(heter_data) r2 = dense_f1(dense_data) @@ -260,12 +254,11 @@ def test_heter_grad(self, transpose, shape): self.assertTrue(bm.allclose(r1, r2)) # grad 'vector' - dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), - argnums=0) - csr_f2 = jax.grad(lambda v: taichi_csr_matvec(heter_data, indices, indptr, v, - shape=shape, - transpose=transpose).sum(), - argnums=0) + dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()), argnums=0) + csr_f2 = jax.grad(lambda v: bm.sparse.csrmv(heter_data, indices, indptr, v, + shape=shape, + transpose=transpose).sum(), + argnums=0) r3 = dense_f2(vector) r4 = csr_f2(vector) self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/_utils.py b/brainpy/_src/math/sparse/utils.py similarity index 100% rename from brainpy/_src/math/sparse/_utils.py rename to brainpy/_src/math/sparse/utils.py diff --git a/requirements-dev-raw.txt b/requirements-dev-raw.txt deleted file mode 100644 index 99361efa..00000000 --- a/requirements-dev-raw.txt +++ /dev/null @@ -1,12 +0,0 @@ -numpy -jax -jaxlib -matplotlib -msgpack -tqdm -pathos - - -# test requirements -pytest -absl-py From 95706178305203f240817ffca32bf03c115e2232 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 10:43:05 +0800 Subject: [PATCH 07/21] Clean taichi AOT caches, Enable to set ``numpy_func_return`` (#643) * rename `clean_caches` to `clear_taichi_aot_caches`, `check_kernels_count` to `count_taichi_aot_kernels` * fix windows CI * fix windows CI * test object load and save states * enable to set `numpy_func_return = True/False` * fix random dtype inconsistency * fix windows tests * fix `numpy_func_return` setting --- brainpy/_src/_delay.py | 2 +- .../tests/test_random_conn_visualize.py | 214 +++++++-------- brainpy/_src/math/__init__.py | 1 - brainpy/_src/math/compat_numpy.py | 28 +- brainpy/_src/math/compat_tensorflow.py | 56 ++-- brainpy/_src/math/defaults.py | 21 +- brainpy/_src/math/environment.py | 28 +- .../_src/math/event/tests/test_event_csrmv.py | 6 + .../math/jitconn/tests/test_event_matvec.py | 7 +- .../_src/math/jitconn/tests/test_matvec.py | 6 + brainpy/_src/math/ndarray.py | 21 +- brainpy/_src/math/object_transform/base.py | 2 +- .../math/object_transform/tests/test_base.py | 50 +++- brainpy/_src/math/op_register/__init__.py | 2 +- .../_src/math/op_register/taichi_aot_based.py | 76 ++++-- .../tests/test_taichi_clean_cache.py | 6 +- brainpy/_src/math/others.py | 4 +- brainpy/_src/math/random.py | 39 +-- brainpy/_src/math/sparse/tests/test_csrmv.py | 7 +- brainpy/_src/math/surrogate/_compt.py | 247 ------------------ brainpy/_src/math/tests/test_environment.py | 15 ++ brainpy/_src/math/tests/test_random.py | 4 + brainpy/math/__init__.py | 7 - brainpy/math/op_register.py | 4 +- docs/apis/brainpy.math.op_register.rst | 32 ++- examples/dynamics_simulation/ei_nets.py | 4 +- 26 files changed, 406 insertions(+), 483 deletions(-) delete mode 100644 brainpy/_src/math/surrogate/_compt.py create mode 100644 brainpy/_src/math/tests/test_environment.py diff --git a/brainpy/_src/_delay.py b/brainpy/_src/_delay.py index a646fd15..bac73e01 100644 --- a/brainpy/_src/_delay.py +++ b/brainpy/_src/_delay.py @@ -144,7 +144,7 @@ def register_entry( delay_type = 'homo' else: delay_type = 'heter' - delay_step = bm.Array(delay_step) + delay_step = delay_step elif callable(delay_step): delay_step = delay_step(self.delay_target_shape) delay_type = 'heter' diff --git a/brainpy/_src/connect/tests/test_random_conn_visualize.py b/brainpy/_src/connect/tests/test_random_conn_visualize.py index 9cd64821..ba0d95f1 100644 --- a/brainpy/_src/connect/tests/test_random_conn_visualize.py +++ b/brainpy/_src/connect/tests/test_random_conn_visualize.py @@ -2,176 +2,178 @@ import pytest +pytest.skip('skip', allow_module_level=True) + import brainpy as bp def test_random_fix_pre1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print(f'num = {num}') - print('conn_mat 1\n', mat1) - print(mat1.sum()) - print('conn_mat 2\n', mat2) - print(mat2.sum()) + print() + print(f'num = {num}') + print('conn_mat 1\n', mat1) + print(mat1.sum()) + print('conn_mat 2\n', mat2) + print(mat2.sum()) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_pre2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = conn1.require(bp.connect.CONN_MAT) - print() - print(mat1) + for num in [0.5, 3]: + conn1 = bp.connect.FixedPreNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print() + print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) + bp.connect.visualizeMat(mat1, 'FixedPreNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_pre3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPreNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') + bp.connect.visualizeMat(conn1, 'FixedPreNum: num=6, pre_size=3, post_size=4') def test_random_fix_post1(): - for num in [0.4, 20]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat1 = conn1.require(bp.connect.CONN_MAT) + for num in [0.4, 20]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat1 = conn1.require(bp.connect.CONN_MAT) - conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) - mat2 = conn2.require(bp.connect.CONN_MAT) + conn2 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=(10, 15), post_size=(10, 20)) + mat2 = conn2.require(bp.connect.CONN_MAT) - print() - print('conn_mat 1\n', mat1) - print('conn_mat 2\n', mat2) + print() + print('conn_mat 1\n', mat1) + print('conn_mat 2\n', mat2) - assert bp.math.array_equal(mat1, mat2) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) + assert bp.math.array_equal(mat1, mat2) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=(10, 15), post_size=(10, 20)' % num) def test_random_fix_post2(): - for num in [0.5, 3]: - conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) - mat1 = conn1.require(bp.connect.CONN_MAT) - print(mat1) - bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) + for num in [0.5, 3]: + conn1 = bp.connect.FixedPostNum(num, seed=1234)(pre_size=5, post_size=4) + mat1 = conn1.require(bp.connect.CONN_MAT) + print(mat1) + bp.connect.visualizeMat(mat1, 'FixedPostNum: num=%s pre_size=5, post_size=4' % num) def test_random_fix_post3(): - with pytest.raises(bp.errors.ConnectorError): - conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) - conn1.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') + with pytest.raises(bp.errors.ConnectorError): + conn1 = bp.connect.FixedPostNum(num=6, seed=1234)(pre_size=3, post_size=4) + conn1.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(conn1, 'FixedPostNum: num=6, pre_size=3, post_size=4') def test_gaussian_prob1(): - conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=1., include_self=False)(pre_size=100) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=1., include_self=False, pre_size=100') def test_gaussian_prob2(): - conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, pre_size=(50, 50)') def test_gaussian_prob3(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) - mat = conn.require(bp.connect.CONN_MAT) + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(50, 50)) + mat = conn.require(bp.connect.CONN_MAT) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(50, 50)') def test_gaussian_prob4(): - conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) - conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - mat = conn.require(bp.connect.CONN_MAT) - bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') + conn = bp.connect.GaussianProb(sigma=4, periodic_boundary=True)(pre_size=(10, 10, 10)) + conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + mat = conn.require(bp.connect.CONN_MAT) + bp.connect.visualizeMat(mat, 'GaussianProb: sigma=4, periodic_boundary=True, pre_size=(10, 10, 10)') def test_SmallWorld1(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) - conn(pre_size=10, post_size=10) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=False) + conn(pre_size=10, post_size=10) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=False, pre_size=10, post_size=10') def test_SmallWorld3(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) - conn(pre_size=20, post_size=20) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5, include_self=True) + conn(pre_size=20, post_size=20) - mat = conn.require(bp.connect.CONN_MAT) + mat = conn.require(bp.connect.CONN_MAT) - print('conn_mat', mat) + print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, include_self=True, pre_size=20, post_size=20') def test_SmallWorld2(): - conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) - conn(pre_size=(100,), post_size=(100,)) + conn = bp.connect.SmallWorld(num_neighbor=2, prob=0.5) + conn(pre_size=(100,), post_size=(100,)) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') + + +def test_ScaleFreeBA(): + conn = bp.connect.ScaleFreeBA(m=2) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, bp.connect.PRE_IDS, bp.connect.POST_IDS, bp.connect.PRE2POST, bp.connect.POST_IDS) print() print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'SmallWorld: num_neighbor=2, prob=0.5, pre_size=(100,), post_size=(100,)') - - -def test_ScaleFreeBA(): - conn = bp.connect.ScaleFreeBA(m=2) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) + bp.connect.visualizeMat(mat, 'ScaleFreeBA: m=2, pre_size=%s, post_size=%s' % (size, size)) def test_ScaleFreeBADual(): - conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.ScaleFreeBADual(m1=2, m2=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'ScaleFreeBADual: m1=2, m2=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) def test_PowerLaw(): - conn = bp.connect.PowerLaw(m=3, p=0.4) - for size in [100, (10, 20), (2, 10, 20)]: - conn(pre_size=size, post_size=size) - mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, - bp.connect.PRE_IDS, bp.connect.POST_IDS, - bp.connect.PRE2POST, bp.connect.POST_IDS) - print() - print('conn_mat', mat) - bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) + conn = bp.connect.PowerLaw(m=3, p=0.4) + for size in [100, (10, 20), (2, 10, 20)]: + conn(pre_size=size, post_size=size) + mat, _, _, _, _ = conn.require(bp.connect.CONN_MAT, + bp.connect.PRE_IDS, bp.connect.POST_IDS, + bp.connect.PRE2POST, bp.connect.POST_IDS) + print() + print('conn_mat', mat) + bp.connect.visualizeMat(mat, 'PowerLaw: m=3, p=0.4, pre_size=%s, post_size=%s' % (size, size)) diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 3102bc1d..de559de5 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -49,7 +49,6 @@ # operators from .op_register import * from .pre_syn_post import * -from .surrogate._compt import * from . import surrogate, event, sparse, jitconn # Variable and Objects for object-oriented JAX transformations diff --git a/brainpy/_src/math/compat_numpy.py b/brainpy/_src/math/compat_numpy.py index 213185df..0eb39145 100644 --- a/brainpy/_src/math/compat_numpy.py +++ b/brainpy/_src/math/compat_numpy.py @@ -103,6 +103,10 @@ _max = max +def _return(a): + return Array(a) + + def fill_diagonal(a, val, inplace=True): if a.ndim < 2: raise ValueError(f'Only support tensor has dimension >= 2, but got {a.shape}') @@ -120,30 +124,30 @@ def fill_diagonal(a, val, inplace=True): def zeros(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def ones(shape, dtype=None): - return Array(jnp.ones(shape, dtype=dtype)) + return _return(jnp.ones(shape, dtype=dtype)) def empty(shape, dtype=None): - return Array(jnp.zeros(shape, dtype=dtype)) + return _return(jnp.zeros(shape, dtype=dtype)) def zeros_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def ones_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.ones_like(a, dtype=dtype, shape=shape)) + return _return(jnp.ones_like(a, dtype=dtype, shape=shape)) def empty_like(a, dtype=None, shape=None): a = _as_jax_array_(a) - return Array(jnp.zeros_like(a, dtype=dtype, shape=shape)) + return _return(jnp.zeros_like(a, dtype=dtype, shape=shape)) def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: @@ -155,7 +159,7 @@ def array(a, dtype=None, copy=True, order="K", ndmin=0) -> Array: leaves = [_as_jax_array_(l) for l in leaves] a = tree_unflatten(tree, leaves) res = jnp.array(a, dtype=dtype, copy=copy, order=order, ndmin=ndmin) - return Array(res) + return _return(res) def asarray(a, dtype=None, order=None): @@ -167,13 +171,13 @@ def asarray(a, dtype=None, order=None): leaves = [_as_jax_array_(l) for l in leaves] arrays = tree_unflatten(tree, leaves) res = jnp.asarray(a=arrays, dtype=dtype, order=order) - return Array(res) + return _return(res) def arange(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.arange(*args, **kwargs)) + return _return(jnp.arange(*args, **kwargs)) def linspace(*args, **kwargs): @@ -181,15 +185,15 @@ def linspace(*args, **kwargs): kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} res = jnp.linspace(*args, **kwargs) if isinstance(res, tuple): - return Array(res[0]), res[1] + return _return(res[0]), res[1] else: - return Array(res) + return _return(res) def logspace(*args, **kwargs): args = [_as_jax_array_(a) for a in args] kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()} - return Array(jnp.logspace(*args, **kwargs)) + return _return(jnp.logspace(*args, **kwargs)) def asanyarray(a, dtype=None, order=None): diff --git a/brainpy/_src/math/compat_tensorflow.py b/brainpy/_src/math/compat_tensorflow.py index 7e9168cf..e9e87e24 100644 --- a/brainpy/_src/math/compat_tensorflow.py +++ b/brainpy/_src/math/compat_tensorflow.py @@ -259,13 +259,13 @@ def segment_sum(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_sum(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_sum(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_prod(data: Union[Array, jnp.ndarray], @@ -311,13 +311,13 @@ def segment_prod(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_prod(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_prod(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_max(data: Union[Array, jnp.ndarray], @@ -363,13 +363,13 @@ def segment_max(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_max(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_max(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def segment_min(data: Union[Array, jnp.ndarray], @@ -415,13 +415,13 @@ def segment_min(data: Union[Array, jnp.ndarray], An array with shape :code:`(num_segments,) + data.shape[1:]` representing the segment sums. """ - return Array(jax.ops.segment_min(as_jax(data), - as_jax(segment_ids), - num_segments, - indices_are_sorted, - unique_indices, - bucket_size, - mode)) + return _return(jax.ops.segment_min(as_jax(data), + as_jax(segment_ids), + num_segments, + indices_are_sorted, + unique_indices, + bucket_size, + mode)) def cast(x, dtype): diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 9f3c5045..eab8b9b6 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -12,32 +12,37 @@ # Default computation mode. mode = NonBatchingMode() -# '''Default computation mode.''' +# Default computation mode. membrane_scaling = IdScaling() -# '''Default time step.''' +# Default time step. dt = 0.1 -# '''Default bool data type.''' +# Default bool data type. bool_ = jnp.bool_ -# '''Default integer data type.''' +# Default integer data type. int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default float data type.''' +# Default float data type. float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default complex data type.''' +# Default complex data type. complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 # register brainpy object as pytree bp_object_as_pytree = False + +# default return array type +numpy_func_return = 'bp_array' # 'bp_array','jax_array' + + if ti is not None: - # '''Default integer data type in Taichi.''' + # Default integer data type in Taichi. ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type in Taichi.''' + # Default float data type in Taichi. ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 else: diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index d49e70f5..ebbb8b6a 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -169,6 +169,7 @@ def __init__( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ) -> None: super().__init__() @@ -208,6 +209,12 @@ def __init__( assert isinstance(bp_object_as_pytree, bool), '"bp_object_as_pytree" must be a bool.' self.old_bp_object_as_pytree = defaults.bp_object_as_pytree + if numpy_func_return is not None: + assert isinstance(numpy_func_return, str), '"numpy_func_return" must be a string.' + assert numpy_func_return in ['bp_array', 'jax_array'], \ + f'"numpy_func_return" must be "bp_array" or "jax_array". Got {numpy_func_return}.' + self.old_numpy_func_return = defaults.numpy_func_return + self.dt = dt self.mode = mode self.membrane_scaling = membrane_scaling @@ -217,6 +224,7 @@ def __init__( self.int_ = int_ self.bool_ = bool_ self.bp_object_as_pytree = bp_object_as_pytree + self.numpy_func_return = numpy_func_return def __enter__(self) -> 'environment': if self.dt is not None: set_dt(self.dt) @@ -228,6 +236,7 @@ def __enter__(self) -> 'environment': if self.complex_ is not None: set_complex(self.complex_) if self.bool_ is not None: set_bool(self.bool_) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.numpy_func_return return self def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: @@ -240,6 +249,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self.complex_ is not None: set_complex(self.old_complex) if self.bool_ is not None: set_bool(self.old_bool) if self.bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = self.old_bp_object_as_pytree + if self.numpy_func_return is not None: defaults.__dict__['numpy_func_return'] = self.old_numpy_func_return def clone(self): return self.__class__(dt=self.dt, @@ -250,7 +260,8 @@ def clone(self): complex_=self.complex_, float_=self.float_, int_=self.int_, - bp_object_as_pytree=self.bp_object_as_pytree) + bp_object_as_pytree=self.bp_object_as_pytree, + numpy_func_return=self.numpy_func_return) def __eq__(self, other): return id(self) == id(other) @@ -279,6 +290,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -288,7 +300,8 @@ def __init__( bool_=bool_, membrane_scaling=membrane_scaling, mode=modes.TrainingMode(batch_size), - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) class batching_environment(environment): @@ -315,6 +328,7 @@ def __init__( batch_size: int = 1, membrane_scaling: scales.Scaling = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): super().__init__(dt=dt, x64=x64, @@ -324,7 +338,8 @@ def __init__( bool_=bool_, mode=modes.BatchingMode(batch_size), membrane_scaling=membrane_scaling, - bp_object_as_pytree=bp_object_as_pytree) + bp_object_as_pytree=bp_object_as_pytree, + numpy_func_return=numpy_func_return) def set( @@ -337,6 +352,7 @@ def set( int_: type = None, bool_: type = None, bp_object_as_pytree: bool = None, + numpy_func_return: str = None, ): """Set the default computation environment. @@ -360,6 +376,8 @@ def set( The bool data type. bp_object_as_pytree: bool Whether to register brainpy object as pytree. + numpy_func_return: str + The array to return in all numpy functions. Support 'bp_array' and 'jax_array'. """ if dt is not None: assert isinstance(dt, float), '"dt" must a float.' @@ -396,6 +414,10 @@ def set( if bp_object_as_pytree is not None: defaults.__dict__['bp_object_as_pytree'] = bp_object_as_pytree + if numpy_func_return is not None: + assert numpy_func_return in ['bp_array', 'jax_array'], f'"numpy_func_return" must be "bp_array" or "jax_array".' + defaults.__dict__['numpy_func_return'] = numpy_func_return + set_environment = set diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index 6c0a2ed4..181ee552 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -14,6 +14,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 6fb8d02e..dd1bafde 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -11,7 +11,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) -shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + shapes = [(100, 200), (1000, 10)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 67c18124..e42bd369 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -11,6 +11,12 @@ if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + shapes = [(100, 200), (1000, 10)] diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index cf2b2343..791c8d9f 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -10,6 +10,7 @@ from jax.tree_util import register_pytree_node_class from brainpy.errors import MathError +from . import defaults bm = None @@ -41,8 +42,8 @@ def _check_input_array(array): def _return(a): - if isinstance(a, jax.Array) and a.ndim > 0: - return Array(a) + if defaults.numpy_func_return == 'bp_array' and isinstance(a, jax.Array) and a.ndim > 0: + return Array(a) return a @@ -1087,7 +1088,7 @@ def unsqueeze(self, dim: int) -> 'Array': See :func:`brainpy.math.unsqueeze` """ - return Array(jnp.expand_dims(self.value, dim)) + return _return(jnp.expand_dims(self.value, dim)) def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': """ @@ -1119,7 +1120,7 @@ def expand_dims(self, axis: Union[int, Sequence[int]]) -> 'Array': self.expand_dims(axis)==self.expand_dims(axis[0]).expand_dims(axis[1])... expand_dims(axis[len(axis)-1]) """ - return Array(jnp.expand_dims(self.value, axis)) + return _return(jnp.expand_dims(self.value, axis)) def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': """ @@ -1136,9 +1137,7 @@ def expand_as(self, array: Union['Array', jax.Array, np.ndarray]) -> 'Array': typically not contiguous. Furthermore, more than one element of a expanded array may refer to a single memory location. """ - if not isinstance(array, Array): - array = Array(array) - return Array(jnp.broadcast_to(self.value, array.value.shape)) + return _return(jnp.broadcast_to(self.value, array)) def pow(self, index: int): return _return(self.value ** index) @@ -1228,7 +1227,7 @@ def absolute_(self): return self.abs_() def mul(self, value): - return Array(self.value * value) + return _return(self.value * value) def mul_(self, value): """ @@ -1404,7 +1403,7 @@ def clip_(self, return self def clone(self) -> 'Array': - return Array(self.value.copy()) + return _return(self.value.copy()) def copy_(self, src: Union['Array', jax.Array, np.ndarray]) -> 'Array': self.value = jnp.copy(_as_jax_array_(src)) @@ -1423,7 +1422,7 @@ def cov_with( fweights = _as_jax_array_(fweights) aweights = _as_jax_array_(aweights) r = jnp.cov(self.value, y, rowvar, bias, fweights, aweights) - return Array(r) + return _return(r) def expand(self, *sizes) -> 'Array': """ @@ -1459,7 +1458,7 @@ def expand(self, *sizes) -> 'Array': raise ValueError( f'The expanded size of the tensor ({sizes_list[base + i]}) must match the existing size ({v}) at non-singleton ' f'dimension {i}. Target sizes: {sizes}. Tensor sizes: {self.shape}') - return Array(jnp.broadcast_to(self.value, sizes_list)) + return _return(jnp.broadcast_to(self.value, sizes_list)) def tree_flatten(self): return (self.value,), None diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 53346a7d..b21ed2af 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -557,7 +557,7 @@ def load_state_dict( missing_keys = [] unexpected_keys = [] for name, node in nodes.items(): - r = node.load_state(state_dict[name], **kwargs) + r = node.load_state(state_dict[name] if name in state_dict else {}, **kwargs) if r is not None: missing, unexpected = r missing_keys.extend([f'{name}.{key}' for key in missing]) diff --git a/brainpy/_src/math/object_transform/tests/test_base.py b/brainpy/_src/math/object_transform/tests/test_base.py index c6f8f90d..ebad7eb0 100644 --- a/brainpy/_src/math/object_transform/tests/test_base.py +++ b/brainpy/_src/math/object_transform/tests/test_base.py @@ -106,8 +106,7 @@ def update(self, x): self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) # print(jax.tree_util.tree_structure(obj)) with bm.environment(mode=bm.TrainingMode()): @@ -116,8 +115,7 @@ def update(self, x): self.assertTrue(len(obj.nodes()) == 7) print(obj.nodes().keys()) - print("obj.nodes(method='relative'): ", - obj.nodes(method='relative').keys()) + print("obj.nodes(method='relative'): ", obj.nodes(method='relative').keys()) # print(jax.tree_util.tree_structure(obj)) @@ -248,5 +246,49 @@ def test1(self): print() +class TestStateSavingAndLoading(unittest.TestCase): + def test_load_states(self): + class Object(bp.DynamicalSystem): + def __init__(self): + super().__init__() + + self.l1 = bp.layers.Dense(5, 10) + self.ls = bm.NodeList([bp.layers.Dense(10, 4), + bp.layers.Activation(bm.tanh), + bp.layers.Dropout(0.1), + bp.layers.Dense(4, 5), + bp.layers.Activation(bm.relu)]) + self.lif = bp.dyn.LifRef(5) + + def update(self, x): + x = self.l1(x) + for l in self.ls: + x = l(x) + return x + + with bm.training_environment(): + obj = Object() + variables = {k: dict(n.vars()) for k, n in obj.nodes(include_self=False).items()} + variables = {k: v for k, v in variables.items() if len(v) > 0} + + all_states = obj.state_dict() + all_states = {k: v for k, v in all_states.items() if len(v) > 0} + print(set(all_states.keys())) + print(set(variables.keys())) + + def not_close(x, y): + assert not bm.allclose(x, y) + def all_close(x, y): + assert bm.allclose(x, y) + + jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array) + + random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array) + jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array) + + obj.load_state_dict(random_state) + jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array) + + diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index ed687eea..21c222c0 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -3,6 +3,6 @@ compile_cpu_signature_with_numba) from .base import XLACustomOp from .utils import register_general_batching -from .taichi_aot_based import clean_caches, check_kernels_count +from .taichi_aot_based import clear_taichi_aot_caches, count_taichi_aot_kernels from .base import XLACustomOp from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index f9328906..595460ea 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -8,7 +8,7 @@ import re import shutil from functools import partial, reduce -from typing import Any, Sequence +from typing import Any, Sequence, Union import jax.core import numpy as np @@ -16,14 +16,17 @@ from jax.lib import xla_client from jaxlib.hlo_helpers import custom_call -from brainpy.errors import PackageMissingError from brainpy._src.dependency_check import (import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops) +from brainpy.errors import PackageMissingError from .utils import _shape_to_layout -### UTILS ### +taichi_cache_path = None + + +# --- UTILS ### # get the path of home directory on Linux, Windows, Mac def get_home_dir(): @@ -43,8 +46,18 @@ def encode_md5(source: str) -> str: return md5.hexdigest() + # check kernels count -def check_kernels_count() -> int: +def count_taichi_aot_kernels() -> int: + """ + Count the number of AOT compiled kernels. + + Returns + ------- + kernels_count: int + The number of AOT compiled kernels. + + """ if not os.path.exists(kernels_aot_path): return 0 kernels_count = 0 @@ -54,23 +67,37 @@ def check_kernels_count() -> int: kernels_count += len(dir2) return kernels_count -# clean caches -def clean_caches(kernels_name: list[str]=None): - if kernels_name is None: - if not os.path.exists(kernels_aot_path): - raise FileNotFoundError("The kernels cache folder does not exist. \ - Please define a kernel using `taichi.kernel` \ - and customize the operator using `bm.XLACustomOp` \ - before calling the operator.") - shutil.rmtree(kernels_aot_path) - print('Clean all kernel\'s cache successfully') + +def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None): + """ + Clean the cache of the AOT compiled kernels. + + Parameters + ---------- + kernels: str or list of str + The name of the kernel to be cleaned. If None, all the kernels will be cleaned. + """ + if kernels is None: + global taichi_cache_path + if taichi_cache_path is None: + from taichi._lib.utils import import_ti_python_core + taichi_cache_path = import_ti_python_core().get_repo_dir() + # clean taichi cache + if os.path.exists(taichi_cache_path): + shutil.rmtree(taichi_cache_path) + # clean brainpy-taichi AOT cache + if os.path.exists(kernels_aot_path): + shutil.rmtree(kernels_aot_path) return - for kernel_name in kernels_name: - try: + if isinstance(kernels, str): + kernels = [kernels] + if not isinstance(kernels, list): + raise TypeError(f'kernels_name must be a list of str, but got {type(kernels)}') + # clear brainpy kernel cache + for kernel_name in kernels: + if os.path.exists(os.path.join(kernels_aot_path, kernel_name)): shutil.rmtree(os.path.join(kernels_aot_path, kernel_name)) - except FileNotFoundError: - raise FileNotFoundError(f'Kernel {kernel_name} does not exist.') - print('Clean kernel\'s cache successfully') + # TODO # not a very good way @@ -104,7 +131,7 @@ def is_metal_supported(): return True -### VARIABLES ### +# --- VARIABLES ### home_path = get_home_dir() kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels') is_metal_device = is_metal_supported() @@ -122,7 +149,7 @@ def _check_kernel_exist(source_md5_encode: str) -> bool: return False -### KERNEL AOT BUILD ### +# --- KERNEL AOT BUILD ### def _array_to_field(dtype, shape) -> Any: @@ -212,7 +239,7 @@ def _build_kernel( kernel.__name__ = kernel_name -### KERNEL CALL PREPROCESS ### +# --- KERNEL CALL PREPROCESS ### # convert type to number type_number_map = { @@ -334,9 +361,6 @@ def _preprocess_kernel_call_gpu( return opaque - - - def _XlaOp_to_ShapedArray(c, xla_op): xla_op = c.get_shape(xla_op) return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()) @@ -376,7 +400,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs): try: os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) except Exception: - raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e + raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e # returns diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py index 5b27b2fd..b534435d 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py @@ -51,10 +51,10 @@ def test_taichi_clean_cache(): print(out) bm.clear_buffer_memory() - print('kernels: ', bm.check_kernels_count()) + print('kernels: ', bm.count_taichi_aot_kernels()) - bm.clean_caches() + bm.clear_taichi_aot_caches() - print('kernels: ', bm.check_kernels_count()) + print('kernels: ', bm.count_taichi_aot_kernels()) # test_taichi_clean_cache() diff --git a/brainpy/_src/math/others.py b/brainpy/_src/math/others.py index 94aeebb1..59588d3b 100644 --- a/brainpy/_src/math/others.py +++ b/brainpy/_src/math/others.py @@ -11,7 +11,7 @@ from .compat_numpy import fill_diagonal from .environment import get_dt, get_int from .interoperability import as_jax -from .ndarray import Array +from .ndarray import Array, _return __all__ = [ 'shared_args_over_time', @@ -79,7 +79,7 @@ def remove_diag(arr): """ if arr.ndim != 2: raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.') - eyes = Array(jnp.ones(arr.shape, dtype=bool)) + eyes = _return(jnp.ones(arr.shape, dtype=bool)) fill_diagonal(eyes, False) return jnp.reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1)) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index d0f74bf2..9ae012bc 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1232,9 +1232,10 @@ def zipf(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optiona a = _check_py_seq(_as_jax_array(a)) if size is None: size = jnp.shape(a) - r = call(lambda x: np.random.zipf(x, size), + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda x: np.random.zipf(x, size).astype(dtype), a, - result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Optional[Union[int, JAX_RAND_KEY]] = None): @@ -1242,8 +1243,10 @@ def power(self, a, size: Optional[Union[int, Sequence[int]]] = None, key: Option if size is None: size = jnp.shape(a) size = _size2shape(size) - r = call(lambda a: np.random.power(a=a, size=size), - a, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) + r = call(lambda a: np.random.power(a=a, size=size).astype(dtype), + a, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, @@ -1256,11 +1259,12 @@ def f(self, dfnum, dfden, size: Optional[Union[int, Sequence[int]]] = None, size = jnp.broadcast_shapes(jnp.shape(dfnum), jnp.shape(dfden)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.f(dfnum=x['dfnum'], dfden=x['dfden'], - size=size), + size=size).astype(dtype), d, - result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequence[int]]] = None, @@ -1274,12 +1278,14 @@ def hypergeometric(self, ngood, nbad, nsample, size: Optional[Union[int, Sequenc jnp.shape(nbad), jnp.shape(nsample)) size = _size2shape(size) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) d = {'ngood': ngood, 'nbad': nbad, 'nsample': nsample} - r = call(lambda x: np.random.hypergeometric(ngood=x['ngood'], - nbad=x['nbad'], - nsample=x['nsample'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + r = call(lambda d: np.random.hypergeometric(ngood=d['ngood'], + nbad=d['nbad'], + nsample=d['nsample'], + size=size).astype(dtype), + d, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, @@ -1288,8 +1294,10 @@ def logseries(self, p, size: Optional[Union[int, Sequence[int]]] = None, if size is None: size = jnp.shape(p) size = _size2shape(size) - r = call(lambda p: np.random.logseries(p=p, size=size), - p, result_shape=jax.ShapeDtypeStruct(size, jnp.int_)) + dtype = jax.dtypes.canonicalize_dtype(jnp.int_) + r = call(lambda p: np.random.logseries(p=p, size=size).astype(dtype), + p, + result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[int]]] = None, @@ -1303,11 +1311,12 @@ def noncentral_f(self, dfnum, dfden, nonc, size: Optional[Union[int, Sequence[in jnp.shape(nonc)) size = _size2shape(size) d = {'dfnum': dfnum, 'dfden': dfden, 'nonc': nonc} + dtype = jax.dtypes.canonicalize_dtype(jnp.float_) r = call(lambda x: np.random.noncentral_f(dfnum=x['dfnum'], dfden=x['dfden'], nonc=x['nonc'], - size=size), - d, result_shape=jax.ShapeDtypeStruct(size, jnp.float_)) + size=size).astype(dtype), + d, result_shape=jax.ShapeDtypeStruct(size, dtype)) return _return(r) # PyTorch compatibility # diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 40bcbb70..acedcff1 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -9,10 +9,15 @@ import brainpy as bp import brainpy.math as bm from brainpy._src.dependency_check import import_taichi - if import_taichi(error_if_not_found=False) is None: pytest.skip('no taichi', allow_module_level=True) +import platform +force_test = False # turn on to force test on windows locally +if platform.system() == 'Windows' and not force_test: + pytest.skip('skip windows', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/surrogate/_compt.py b/brainpy/_src/math/surrogate/_compt.py deleted file mode 100644 index 67b7d515..00000000 --- a/brainpy/_src/math/surrogate/_compt.py +++ /dev/null @@ -1,247 +0,0 @@ -# -*- coding: utf-8 -*- - -import warnings - -from jax import custom_gradient, numpy as jnp - -from brainpy._src.math.compat_numpy import asarray -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.environment import get_float -from brainpy._src.math.ndarray import Array - -__all__ = [ - 'spike_with_sigmoid_grad', - 'spike_with_linear_grad', - 'spike_with_gaussian_grad', - 'spike_with_mg_grad', - - 'spike2_with_sigmoid_grad', - 'spike2_with_linear_grad', -] - - -def _consistent_type(target, compare): - return as_jax(target) if not isinstance(compare, Array) else asarray(target) - - -@custom_gradient -def spike_with_sigmoid_grad(x: Array, scale: float = 100.): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad()` instead.', UserWarning) - - x = as_jax(x) - z = jnp.asarray(x >= 0, dtype=get_float()) - - def grad(dE_dz): - dE_dz = as_jax(dE_dz) - dE_dx = dE_dz / (scale * jnp.abs(x) + 1.0) ** 2 - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(scale) - return (dE_dx, dscale) - - return z, grad - - -@custom_gradient -def spike2_with_sigmoid_grad(x_new: Array, x_old: Array, scale: float = None): - """Spike function with the sigmoid surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.inv_square_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: optional, float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.inv_square_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 100. if scale is None else scale - dx_new = (dE_dz / (_scale * jnp.abs(x_new) + 1.0) ** 2) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz / (_scale * jnp.abs(x_old) + 1.0) ** 2) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -@custom_gradient -def spike_with_linear_grad(x: Array, scale: float = None): - """Spike function with the relu surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: Array - The input data. - scale: float - The scaling factor. - """ - - warnings.warn('Use `brainpy.math.surrogate.relu_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dE_dx = dE_dz * jnp.maximum(1 - jnp.abs(x), 0) * _scale - if scale is None: - return (_consistent_type(dE_dx, x),) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dE_dx, x), _consistent_type(dscale, _scale)) - - return z, grad - - -@custom_gradient -def spike2_with_linear_grad(x_new: Array, x_old: Array, scale: float = 10.): - """Spike function with the linear surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.relu_grad2()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x_new: Array - The input data. - x_old: Array - The input data. - scale: float - The scaling factor. - """ - warnings.warn('Use `brainpy.math.surrogate.relu_grad2()` instead.', UserWarning) - - x_new_comp = x_new >= 0 - x_old_comp = x_old < 0 - z = jnp.asarray(jnp.logical_and(x_new_comp, x_old_comp), dtype=get_float()) - - def grad(dE_dz): - _scale = 0.3 if scale is None else scale - dx_new = (dE_dz * jnp.maximum(1 - jnp.abs(x_new), 0) * _scale) * jnp.asarray(x_old_comp, dtype=get_float()) - dx_old = -(dE_dz * jnp.maximum(1 - jnp.abs(x_old), 0) * _scale) * jnp.asarray(x_new_comp, dtype=get_float()) - if scale is None: - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old)) - else: - dscale = jnp.zeros_like(_scale) - return (_consistent_type(dx_new, x_new), - _consistent_type(dx_old, x_old), - _consistent_type(dscale, scale)) - - return z, grad - - -def _gaussian(x, mu, sigma): - return jnp.exp(-((x - mu) ** 2) / (2 * sigma ** 2)) / jnp.sqrt(2 * jnp.pi) / sigma - - -@custom_gradient -def spike_with_gaussian_grad(x, sigma=None, scale=None): - """Spike function with the Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.gaussian_grad()`` instead. - Will be removed after version 2.4.0. - - """ - - warnings.warn('Use `brainpy.math.surrogate.gaussian_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _scale = 0.5 if scale is None else scale - _sigma = 0.5 if sigma is None else sigma - dE_dx = dE_dz * _gaussian(x, 0., _sigma) * _scale - returns = (_consistent_type(dE_dx, x),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad - - -@custom_gradient -def spike_with_mg_grad(x, h=None, s=None, sigma=None, scale=None): - """Spike function with the multi-Gaussian surrogate gradient. - - .. deprecated:: 2.3.1 - Please use ``brainpy.math.surrogate.multi_sigmoid_grad()`` instead. - Will be removed after version 2.4.0. - - Parameters - ---------- - x: ndarray - The variable to judge spike. - h: float - The hyper-parameters of approximate function - s: float - The hyper-parameters of approximate function - sigma: float - The gaussian sigma. - scale: float - The gradient scale. - """ - - warnings.warn('Use `brainpy.math.surrogate.multi_sigmoid_grad()` instead.', UserWarning) - - z = jnp.asarray(x >= 0., dtype=get_float()) - - def grad(dE_dz): - _sigma = 0.5 if sigma is None else sigma - _scale = 0.5 if scale is None else scale - _s = 6.0 if s is None else s - _h = 0.15 if h is None else h - dE_dx = dE_dz * (_gaussian(x, mu=0., sigma=_sigma) * (1. + _h) - - _gaussian(x, mu=_sigma, sigma=_s * _sigma) * _h - - _gaussian(x, mu=-_sigma, sigma=_s * _sigma) * _h) * _scale - returns = (_consistent_type(dE_dx, x),) - if h is not None: - returns += (_consistent_type(jnp.zeros_like(_h), h),) - if s is not None: - returns += (_consistent_type(jnp.zeros_like(_s), s),) - if sigma is not None: - returns += (_consistent_type(jnp.zeros_like(_sigma), sigma),) - if scale is not None: - returns += (_consistent_type(jnp.zeros_like(_scale), scale),) - return returns - - return z, grad - diff --git a/brainpy/_src/math/tests/test_environment.py b/brainpy/_src/math/tests/test_environment.py new file mode 100644 index 00000000..83315899 --- /dev/null +++ b/brainpy/_src/math/tests/test_environment.py @@ -0,0 +1,15 @@ +import unittest + +import jax + +import brainpy.math as bm + + +class TestEnvironment(unittest.TestCase): + def test_numpy_func_return(self): + with bm.environment(numpy_func_return='jax_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, jax.Array)) + with bm.environment(numpy_func_return='bp_array'): + a = bm.random.randn(3, 3) + self.assertTrue(isinstance(a, bm.Array)) diff --git a/brainpy/_src/math/tests/test_random.py b/brainpy/_src/math/tests/test_random.py index 63b77064..1621f43d 100644 --- a/brainpy/_src/math/tests/test_random.py +++ b/brainpy/_src/math/tests/test_random.py @@ -1,8 +1,10 @@ +import platform import unittest import jax.numpy as jnp import jax.random as jr import numpy as np +import pytest import brainpy.math as bm import brainpy.math.random as br @@ -354,11 +356,13 @@ def test_hypergeometric1(self): a = bm.random.hypergeometric(10, 10, 10, 20) self.assertTupleEqual(a.shape, (20,)) + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') def test_hypergeometric2(self): br.seed() a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]]) self.assertTupleEqual(a.shape, (2, 2)) + @pytest.mark.skipif(platform.system() == 'Windows', reason='Windows jaxlib error') def test_hypergeometric3(self): br.seed() a = bm.random.hypergeometric(8, [10, 4], [[5, 2], [5, 5]], size=(3, 2, 2)) diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 9a64f9f2..08a070f0 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -44,13 +44,6 @@ del jnp, config -from brainpy._src.math.surrogate._compt import ( - spike_with_sigmoid_grad as spike_with_sigmoid_grad, - spike_with_linear_grad as spike_with_linear_grad, - spike_with_gaussian_grad as spike_with_gaussian_grad, - spike_with_mg_grad as spike_with_mg_grad, -) - from brainpy._src.math import defaults from brainpy._src.deprecations import deprecation_getattr from brainpy._src.dependency_check import import_taichi, import_numba diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index c0fcb67a..f383c1a2 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -2,8 +2,8 @@ from brainpy._src.math.op_register import ( CustomOpByNumba, compile_cpu_signature_with_numba, - clean_caches, - check_kernels_count, + clear_taichi_aot_caches, + count_taichi_aot_kernels, ) from brainpy._src.math.op_register.base import XLACustomOp diff --git a/docs/apis/brainpy.math.op_register.rst b/docs/apis/brainpy.math.op_register.rst index a50b4d30..13ce518c 100644 --- a/docs/apis/brainpy.math.op_register.rst +++ b/docs/apis/brainpy.math.op_register.rst @@ -22,6 +22,23 @@ General Operator Customization Interface +CPU Operator Customization with Taichi +------------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + + clear_taichi_aot_caches + count_taichi_aot_kernels + + + + + + CPU Operator Customization with Numba ------------------------------------- @@ -34,7 +51,6 @@ CPU Operator Customization with Numba :template: classtemplate.rst CustomOpByNumba - XLACustomOp .. autosummary:: @@ -43,3 +59,17 @@ CPU Operator Customization with Numba register_op_with_numba compile_cpu_signature_with_numba + + +Operator Autograd Customization +------------------------------- + +.. currentmodule:: brainpy.math +.. automodule:: brainpy.math + +.. autosummary:: + :toctree: generated/ + + defjvp + + diff --git a/examples/dynamics_simulation/ei_nets.py b/examples/dynamics_simulation/ei_nets.py index f9852745..7923c93d 100644 --- a/examples/dynamics_simulation/ei_nets.py +++ b/examples/dynamics_simulation/ei_nets.py @@ -213,7 +213,7 @@ def __init__(self): super().__init__() self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 1.)) - self.delay = bp.VarDelay(self.N.spike, entries={'delay': 2}) + self.delay = bp.VarDelay(self.N.spike, entries={'delay': 0.}) self.syn1 = bp.dyn.Expon(size=3200, tau=5.) self.syn2 = bp.dyn.Expon(size=800, tau=10.) self.E = bp.dyn.VanillaProj( @@ -228,7 +228,7 @@ def __init__(self): ) def update(self, input): - spk = self.delay.at('I') + spk = self.delay.at('delay') self.E(self.syn1(spk[:3200])) self.I(self.syn2(spk[3200:])) self.delay(self.N(input)) From 5f1b9051270bbefa8e82b8f0386a9d3f35ce212f Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Sun, 3 Mar 2024 13:38:14 +0800 Subject: [PATCH 08/21] [ci] Fix windows pytest fatal exception (#644) Update CI.yml --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 7f46c959..01bdd87f 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -113,4 +113,4 @@ jobs: - name: Test with pytest run: | cd brainpy - pytest _src/ + pytest _src/ -p no:faulthandler From a8377cba17f9193d80a58bd5668801ba70ef1633 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Sun, 3 Mar 2024 14:03:56 +0800 Subject: [PATCH 09/21] [math] Support more than 8 parameters when using taichi gpu operator customization (#642) * Update taichi_aot_based.py * Update _event_matvec.py * Update _event_matvec.py --- brainpy/_src/math/jitconn/event_matvec.py | 2 +- brainpy/_src/math/op_register/taichi_aot_based.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index 27998038..a22aac75 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -1157,4 +1157,4 @@ def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( cpu_kernel=_event_mv_prob_normal_cpu, gpu_kernel=_event_mv_prob_normal_gpu - ) + ) \ No newline at end of file diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 595460ea..2a8cb3b6 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -324,11 +324,12 @@ def _preprocess_kernel_call_gpu( kernel_path = os.path.join(kernels_aot_path, source_md5_encode) # other args + param_total_num = len(ins) + len(outs) in_out_num = [len(ins), len(outs)] - in_out_type_list = [0] * 8 - in_out_dim_count_list = [0] * 8 - in_out_elem_count_list = [0] * 8 - in_out_shape_list = [0] * 64 + in_out_type_list = [0] * param_total_num + in_out_dim_count_list = [0] * param_total_num + in_out_elem_count_list = [0] * param_total_num + in_out_shape_list = [0] * param_total_num * 8 for i, value in enumerate(ins.values()): in_out_type_list[i] = type_number_map[value[0]] From 5e541076371493296ca7a2e954aa5c692a420ab7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 14:04:21 +0800 Subject: [PATCH 10/21] Doc for ``brainpylib>=0.3.0`` (#645) * update docs * update changelogs * add `noise` option to neurons in `brainpy.dyn` --- .github/workflows/docs.yml | 43 - brainpy-changelog.md | 2663 +++++++++++++++++ brainpy/_src/dyn/neurons/hh.py | 31 +- brainpy/_src/dyn/neurons/lif.py | 124 +- .../_src/dynold/neurons/biological_models.py | 13 +- brainpy/_src/dynold/neurons/reduced_models.py | 30 +- brainpylib-changelog.md | 62 + changelog.rst | 1083 ------- docs/api.rst | 3 +- docs/conf.py | 3 +- docs/index.rst | 15 +- docs/quickstart/installation.rst | 30 +- examples/operator_customization/event_ell.py | 40 + setup.py | 9 +- 14 files changed, 2939 insertions(+), 1210 deletions(-) delete mode 100644 .github/workflows/docs.yml create mode 100644 brainpy-changelog.md create mode 100644 brainpylib-changelog.md delete mode 100644 changelog.rst create mode 100644 examples/operator_customization/event_ell.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml deleted file mode 100644 index 0c515d77..00000000 --- a/.github/workflows/docs.yml +++ /dev/null @@ -1,43 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Make documentation - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - - -jobs: - make_docs: - runs-on: - group: Default - labels: self-hosted - - steps: - - uses: actions/checkout@v4 - - uses: conda-incubator/setup-miniconda@v3 - with: - auto-update-conda: true - python-version: "3.10" - miniconda-version: "latest" - - name: Conda info - shell: bash -el {0} - run: conda info - - name: Install dependencies - shell: bash -el {0} - run: | - conda activate - python -m pip install --upgrade pip - python -m pip install flake8 pytest - if [ -f requirements-doc.txt ]; then pip install -r requirements-doc.txt; fi - pip uninstall brainpy -y - python setup.py install - - name: Make docs - shell: bash -el {0} - run: | - conda activate - cd ~/brainpy_docs/docs - make html \ No newline at end of file diff --git a/brainpy-changelog.md b/brainpy-changelog.md new file mode 100644 index 00000000..c949b701 --- /dev/null +++ b/brainpy-changelog.md @@ -0,0 +1,2663 @@ +# Release notes (``brainpy``) + + +## brainpy>2.3.x + + +### Version 2.5.0 + + +This release contains many new features and fixes. It is the first release with a mature solution for Brain Dynamics Operator Customization on both CPU and GPU platforms. + + +#### New Features + +1. Add synapse projection with Delta synapse models through ``brainpy.dyn.HalfProjDelta`` and ``brainpy.dyn.FullProjDelta``. +2. Add ``brainpy.math.exprel``, and change the code in the corresponding HH neuron models to improve numerical computation accuracy. These changes can significantly improve the numerical integration accuracy of HH-like models under x32 computation. +3. Add ``brainpy.reset_level()`` decorator so that the state resetting order can be customized by users. +4. Add ``brainpy.math.ein_rearrange``, ``brainpy.math.ein_reduce``, and ``brainpy.math.ein_repeat`` functions +5. Add ``brainpy.math.scan`` transformation. +6. Rebase all customized operators using Taichi JIT compiler. On the CPU platform, the speed performance can be boosted ten to hundred times. On the GPU platforms, the flexibility can be greatly improved. +7. Many bug fixes. +8. A new version of ``brainpylib>=0.2.4`` has been released, supporting operator customization through the Taichi compiler. The supported backends include Linux, Windows, MacOS Intel, and MacOS M1 platforms. Tutorials please see https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html + +#### What's Changed +* [docs] Add taichi customized operators tutorial by @Routhleck in https://github.com/brainpy/BrainPy/pull/545 +* [docs] Optimize tutorial code in `operator_custom_with_taichi.ipynb` of documentations by @Routhleck in https://github.com/brainpy/BrainPy/pull/546 +* [running] fix multiprocessing bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/547 +* [docs] Fix typo in docs by @Routhleck in https://github.com/brainpy/BrainPy/pull/549 +* :arrow_up: Bump conda-incubator/setup-miniconda from 2 to 3 by @dependabot in https://github.com/brainpy/BrainPy/pull/551 +* updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/550 +* ``brainpy.math.defjvp`` and ``brainpy.math.XLACustomOp.defjvp`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/554 +* :arrow_up: Bump actions/setup-python from 4 to 5 by @dependabot in https://github.com/brainpy/BrainPy/pull/555 +* Fix ``brainpy.math.ifelse`` bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/556 +* [math & dyn] add ``brainpy.math.exprel``, and change the code in the corresponding HH neuron models to improve numerical computation accuracy by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/557 +* Update README by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/558 +* [doc] add conductance neuron model tutorial by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/559 +* Doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/560 +* add `brainpy.math.functional_vector_grad` and `brainpy.reset_level()` decorator by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/561 +* [math] change the internal implementation of surrogate function by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/562 +* Math by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/563 +* [doc] update citations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/564 +* add support for multi-class margin loss by @charlielam0615 in https://github.com/brainpy/BrainPy/pull/566 +* Support for Delta synapse projections by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/568 +* [math] Add taichi customized operators(event csrmv, csrmv, jitconn event mv, jitconn mv) by @Routhleck in https://github.com/brainpy/BrainPy/pull/553 +* fix doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/571 +* Fix default math parameter setting bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/572 +* fix bugs in `brainpy.math.random.truncated_normal` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/574 +* [doc] fix doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/576 +* fix bugs in truncated_normal; add TruncatedNormal init. by @charlielam0615 in https://github.com/brainpy/BrainPy/pull/575 +* [Dyn] Fix alpha synapse bugs by @ztqakita in https://github.com/brainpy/BrainPy/pull/578 +* fix `brainpy.math.softplus` and `brainpy.dnn.SoftPlus` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/581 +* add `TruncatedNormal` to `initialize.py` by @charlielam0615 in https://github.com/brainpy/BrainPy/pull/583 +* Fix `_format_shape` in `random_inits.py` by @charlielam0615 in https://github.com/brainpy/BrainPy/pull/584 +* fix bugs in `truncated_normal` by @charlielam0615 in https://github.com/brainpy/BrainPy/pull/585 +* [dyn] fix warning of reset_state by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/587 +* [math] upgrade variable retrival by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/589 +* [math & dnn] add `brainpy.math.unflatten` and `brainpy.dnn.Unflatten` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/588 +* [math] add ``ein_rearrange``, ``ein_reduce``, and ``ein_repeat`` functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/590 +* [math] Support taichi customized op with metal cpu backend by @Routhleck in https://github.com/brainpy/BrainPy/pull/579 +* Doc fix and standardize Dual Exponential model again by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/591 +* update doc, upgrade reset_state, update projection models by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/592 +* [taichi] Make taichi caches more transparent and Add clean caches function by @Routhleck in https://github.com/brainpy/BrainPy/pull/596 +* [test] remove test skip on macos, since brainpylib supports taichi interface on macos by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/597 +* [dyn] add `clear_input` in the `step_run` function. by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/601 +* [math] Refactor taichi operators by @Routhleck in https://github.com/brainpy/BrainPy/pull/598 +* [math] fix `brainpy.math.scan` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/604 +* ``disable_ jit`` support in ``brainpy.math.scan`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/606 +* [math] Remove the logs that `taichi.init()` print by @Routhleck in https://github.com/brainpy/BrainPy/pull/609 +* Version control in Publish.yml CI by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/610 + +#### New Contributors +* @charlielam0615 made their first contribution in https://github.com/brainpy/BrainPy/pull/566 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.6...V2.5.0 + + + + +### Version 2.4.6 + +This release contains more than 130 commit updates, and has provided several new features. + + +#### New Features + + +##### 1. surrogate gradient functions are more transparent. + +New instances can be used to compute the surrogate gradients. For example: + +```python +import brainpy.math as bm +fun = bm.surrogate.Sigmoid() + +# forward function +spk = fun(membrane_potential) + +# backward function +dV = fun.surrogate_grad(1., membrane_potential) + +# surrogate forward function +surro_spk = fun.surrogate_fun(membrane_potential) +``` + +##### 2. Add ``brainpy.math.eval_shape`` for evaluating the all dynamical variables used in the target function. + +This function is similar to ``jax.eval_shape`` which has no FLOPs, while it can extract all variables used in the target function. For example: + +```python +net = ... # any dynamical system +inputs = ... # inputs to the dynamical system +variables, outputs= bm.eval_shape(net, inputs) +# "variables" are all variables used in the target "net" +``` + +In future, this function will be used everywhere to transform all jax transformations into brainpy's oo transformations. + +##### 3. Generalize tools and interfaces for state managements. + +For a single object: +- The ``.reset_state()`` defines the state resetting of all local variables in this node. +- The ``.load_state()`` defines the state loading from external disks (typically, a dict is passed into this ``.load_state()`` function). +- The ``.save_state()`` defines the state saving to external disks (typically, the ``.save_state()`` function generates a dict containing all variable values). + +Here is an example to define a full class of ``brainpy.DynamicalSystem``. + +```python +import brainpy as bp + +class YouDynSys(bp.DynamicalSystem): + def __init__(self, ): # define parameters + self.par1 = .... + self.num = ... + + def reset_state(self, batch_or_mode=None): # define variables + self.a = bp.init.variable_(bm.zeros, (self.num,), batch_or_mode) + + def load_state(self, state_dict): # load states from an external dict + self.a.value = bm.as_jax(state_dict['a']) + + def save_state(self): # save states as an external dict + return {'a': self.a.value} +``` + + +For a complex network model, brainpy provide unified state managment interface for initializing, saving, and loading states. +- The ``brainpy.reset_state()`` defines the state resetting of all variables in this node and its children nodes. +- The ``brainpy.load_state()`` defines the state loading from external disks of all variables in the node and its children. +- The ``brainpy.save_state()`` defines the state saving to external disks of all variables in the node and its children. +- The ``brainpy.clear_input()`` defines the clearing of all input variables in the node and its children. + + + + +##### 4. Unified brain simulation and brain-inspired computing interface through automatic membrane scaling. + +The same model used in brain simulation can be easily transformed into the one used for brain-inspired computing for training. For example, + + +```python +class EINet(bp.DynSysGroup): + def __init__(self): + super().__init__() + self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.)) + self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) + self.E = bp.dyn.ProjAlignPost1( + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=3200, post=4000), weight=bp.init.Normal(0.6, 0.01)), + syn=bp.dyn.Expon(size=4000, tau=5.), + out=bp.dyn.COBA(E=0.), + post=self.N + ) + self.I = bp.dyn.ProjAlignPost1( + comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(0.02, pre=800, post=4000), weight=bp.init.Normal(6.7, 0.01)), + syn=bp.dyn.Expon(size=4000, tau=10.), + out=bp.dyn.COBA(E=-80.), + post=self.N + ) + + def update(self, input): + spk = self.delay.at('I') + self.E(spk[:3200]) + self.I(spk[3200:]) + self.delay(self.N(input)) + return self.N.spike.value + + +# used for brain simulation +with bm.environment(mode=bm.nonbatching_mode): + net = EINet() + + +# used for brain-inspired computing +# define the `membrane_scaling` parameter +with bm.environment(mode=bm.TrainingMode(128), membrane_scaling=bm.Scaling.transform([-60., -50.])): + net = EINet() +``` + + + +##### 5. New apis for operator customization on CPU and GPU devices through ``brainpy.math.XLACustomOp``. + +Starting from this release, brainpy introduces [Taichi](https://github.com/taichi-dev/taichi) for operator customization. Now, users can write CPU and GPU operators through numba and taichi syntax on CPU device, and taichi syntax on GPu device. Particularly, to define an operator, user can use: + +```python + +import numba as nb +import taichi as ti +import numpy as np +import jax +import brainpy.math as bm + + +@nb.njit +def numba_cpu_fun(a, b, out_a, out_b): + out_a[:] = a + out_b[:] = b + + +@ti.kernel +def taichi_gpu_fun(a, b, out_a, out_b): + for i in range(a.size): + out_a[i] = a[i] + for i in range(b.size): + out_b[i] = b[i] + + +prim = bm.XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun) +a2, b2 = prim(np.random.random(1000), np.random.random(1000), + outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), + jax.ShapeDtypeStruct(1000, dtype=np.float32)]) + +``` + +##### 6. Generalized STDP models which are compatible with diverse synapse models. + +See https://github.com/brainpy/BrainPy/blob/master/brainpy/_src/dyn/projections/tests/test_STDP.py + + +#### What's Changed +* [bug] fix compatible bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/508 +* [docs] add low-level op customization by @ztqakita in https://github.com/brainpy/BrainPy/pull/507 +* Compatible with `jax==0.4.16` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/511 +* updates for parallelization support by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/514 +* Upgrade surrogate gradient functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/516 +* [doc] update operator customization by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/517 +* Updates for OO transforma and surrogate functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/519 +* [dyn] add neuron scaling by @ztqakita in https://github.com/brainpy/BrainPy/pull/520 +* State saving, loading, and resetting by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/521 +* [delay] rewrite previous delay APIs so that they are compatible with new brainpy version by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/522 +* [projection] upgrade projections so that APIs are reused across different models by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/523 +* [math] the interface for operator registration by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/524 +* FIx bug in Delay by @ztqakita in https://github.com/brainpy/BrainPy/pull/525 +* Fix bugs in membrane scaling by @ztqakita in https://github.com/brainpy/BrainPy/pull/526 +* [math] Implement taichi op register by @Routhleck in https://github.com/brainpy/BrainPy/pull/527 +* Link libtaichi_c_api.so when import brainpylib by @Routhleck in https://github.com/brainpy/BrainPy/pull/528 +* update taichi op customization by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/529 +* Fix error message by @HoshinoKoji in https://github.com/brainpy/BrainPy/pull/530 +* [math] remove the hard requirement of `taichi` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/531 +* [math] Resolve encoding of source kernel when ti.func is nested in ti… by @Routhleck in https://github.com/brainpy/BrainPy/pull/532 +* [math] new abstract function for XLACustomOp, fix its bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/534 +* [math] fix numpy array priority by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/533 +* [brainpy.share] add category shared info by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/535 +* [doc] update documentations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/536 +* [doc] update doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/537 +* [dyn] add `brainpy.reset_state()` and `brainpy.clear_input()` for more consistent and flexible state managements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/538 +* [math] simplify the taichi AOT operator customization interface by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/540 +* [dyn] add `save_state`, `load_state`, `reset_state`, and `clear_input` helpers by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/542 +* [dyn] update STDP APIs on CPUs and fix bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/543 + +#### New Contributors +* @HoshinoKoji made their first contribution in https://github.com/brainpy/BrainPy/pull/530 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.5...V2.4.6 + + + + + +### Version 2.4.5 + + +#### New Features + +- A new version of ``brainpylib==0.1.10`` has been released. In this release, we have fixed some bugs of brainpy dedicated GPU operators. Users can freely use them in any application. +- Correspondingly, dedicated operators in ``brainpy.math`` have been refined. +- ``.tracing_variable()`` has been created to support tracing ``Variable``s during computations and compilations. Example usage please see #472 +- Add a new random API for creating multiple random keys: ``brainpy.math.random.split_keys()``. +- Fix bugs, including + - ``brainpy.dnn.AllToAll`` module + - RandomState. + - ``brainpy.math.cond`` and ``brainpy.math.while_loop`` when variables are used in both branches + +#### What's Changed +* Creat random key automatically when it is detected by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/461 +* [encoding] upgrade encoding methods by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/464 +* fix #466 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/467 +* Update operators for compatible with ``brainpylib>=0.1.10`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/468 +* Support tracing ``Variable`` during computation and compilation by using ``tracing_variable()`` function by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/472 +* Add code of conduct and contributing guides by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/473 +* add Funding and Development roadmap by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/475 +* Create SECURITY.md by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/474 +* Create dependabot.yml by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/476 +* update maintainence info in README by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/479 +* :arrow_up: Bump actions/setup-python from 2 to 4 by @dependabot in https://github.com/brainpy/BrainPy/pull/477 +* :arrow_up: Bump actions/checkout from 2 to 4 by @dependabot in https://github.com/brainpy/BrainPy/pull/478 +* ad acknowledgment.md by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/482 +* update quickstart of `simulating a brain dynamics model` with new APIs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/483 +* update advanced tutorials by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/484 +* [docs] Update installation.rst by @Routhleck in https://github.com/brainpy/BrainPy/pull/485 +* update requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/486 +* [doc] update docs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/487 +* [doc] update docs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/488 +* Decouple Online and Offline training algorithms as ``brainpy.mixin.SupportOnline`` and `brainpy.mixin.SupportOffline` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/489 +* [dyn] add STDP_Song2000 LTP model by @ztqakita in https://github.com/brainpy/BrainPy/pull/481 +* update STDP by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/491 +* [doc] update the API of `brainpy.dyn` module & add synaptic plasticity module by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/492 +* fix bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/493 +* [math] fix bugs in `cond` and `while_loop` when same variables are used in both branches by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/494 +* [docs] add BrainPy docker and docs by @ztqakita in https://github.com/brainpy/BrainPy/pull/496 +* [docs] update README and installation by @ztqakita in https://github.com/brainpy/BrainPy/pull/499 +* :arrow_up: Bump docker/build-push-action from 4 to 5 by @dependabot in https://github.com/brainpy/BrainPy/pull/498 +* :arrow_up: Bump docker/login-action from 2 to 3 by @dependabot in https://github.com/brainpy/BrainPy/pull/497 +* Add strings in bp._src.dyn.bio_models and abstract_models by @AkitsuFaye in https://github.com/brainpy/BrainPy/pull/500 +* [reset] update logics of state reset in `DynamicalSystem` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/501 +* [doc] upgrade docs with the latest APIs, fix #463 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/502 +* [doc] add synapse model documentations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/503 +* Changed the order of code blocks in the docs of hh models and lif models by @AkitsuFaye in https://github.com/brainpy/BrainPy/pull/505 +* [mode] move recurrent models in brainpy.dnn model into `brainpy.dyn` module by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/506 + +#### New Contributors +* @dependabot made their first contribution in https://github.com/brainpy/BrainPy/pull/477 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.4...V2.4.5 + + + + + + +### Version 2.4.4 + + + +This release has fixed several bugs and updated the sustainable documentation. + +#### What's Changed +* [mixin] abstract the behavior of supporting input projection by ``brainpy.mixin.ReceiveInputProj`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/428 +* Update delays, models, and projections by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/429 +* Compatible with `jax=0.4.14` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/431 +* Add new tests by @yygf123 in https://github.com/brainpy/BrainPy/pull/430 +* Add NonBatchingMode function by @yygf123 in https://github.com/brainpy/BrainPy/pull/433 +* [connect] Complete `FixedTotalNum` class and fix bugs by @Routhleck in https://github.com/brainpy/BrainPy/pull/434 +* Update the document "Concept 2: Dynamical System" by @yygf123 in https://github.com/brainpy/BrainPy/pull/435 +* [docs] Update three part of tutorial toolbox by @Routhleck in https://github.com/brainpy/BrainPy/pull/436 +* [docs] Update index.rst for surrogate gradient by @Routhleck in https://github.com/brainpy/BrainPy/pull/437 +* Reconstruct BrainPy documentations by @ztqakita in https://github.com/brainpy/BrainPy/pull/438 +* Renew doc requirements.txt by @ztqakita in https://github.com/brainpy/BrainPy/pull/441 +* Compatibility updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/442 +* update docs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/443 +* Update optimizer by @yygf123 in https://github.com/brainpy/BrainPy/pull/451 +* [docs] Update custom saving and loading by @Routhleck in https://github.com/brainpy/BrainPy/pull/439 +* [doc] add new strings in bp._src.dyn.hh.py and bp._src.dyn.lif.py by @AkitsuFaye in https://github.com/brainpy/BrainPy/pull/454 +* Serveral updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/452 +* Update doc bug in index.rst by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/458 +* add `brainpy.dyn.Alpha` synapse model by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/459 +* [doc] update ODE doc by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/460 + +#### New Contributors +* @AkitsuFaye made their first contribution in https://github.com/brainpy/BrainPy/pull/454 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.3...V2.4.4 + + + + + +### Version 2.4.3 + + +This release has standardized the modeling of DNN and SNN models by two intercorrelated packages: ``brainpy.dnn`` and ``brainpy.dyn``. + +Overall, the modeling of brain dynamics in this release has the following advantages: + +- the automatic merging of the duplicate synapses, keeping the minimal device memory +- easy model and data parallelization across multiple devices +- easy integration with artificial neural networks +- a new abstraction that decouples dynamics from communication +- the unified ``DynamicalSystem`` interface + +#### New Features + +1. Support to define ion channel models which rely on multiple ions. For example, + +```python + +class HH(bp.dyn.CondNeuGroup): + def __init__(self, size): + super().__init__(size) + self.k = bp.dyn.PotassiumFixed(size) + self.ca = bp.dyn.CalciumFirstOrder(size) + + self.kca = bp.dyn.mix_ions(self.k, self.ca) # Ion that mixing Potassium and Calcium + self.kca.add_elem(ahp=bp.dyn.IAHP_De1994v2(size)) # channel that relies on both Potassium and Calcium + +``` + +2. New style ``.update()`` function in ``brainpy.DynamicalSystem`` which resolves all compatible issues. Starting from this version, all ``update()`` no longer needs to receive a global shared argument such as ``tdi``. + +```python + +class YourDynSys(bp.DynamicalSystem): + def update(self, x): + t = bp.share['t'] + dt = bp.share['dt'] + i = bp.share['i'] + ... + +``` + +3. Optimize the connection-building process when using ``brainpy.conn.ScaleFreeBA``, ``brainpy.conn.ScaleFreeBADual``, ``brainpy.conn.PowerLaw`` + +4. New dual exponential model ``brainpy.dyn.DualExponV2`` can be aligned with post dimension. + +5. More synaptic projection abstractions, including + - ``brainpy.dyn.VanillaProj`` + - ``brainpy.dyn.ProjAlignPostMg1`` + - ``brainpy.dyn.ProjAlignPostMg2`` + - ``brainpy.dyn.ProjAlignPost1`` + - ``brainpy.dyn.ProjAlignPost2`` + - ``brainpy.dyn.ProjAlignPreMg1`` + - ``brainpy.dyn.ProjAlignPreMg2`` + +5. Fix compatible issues, fix unexpected bugs, and improve the model tests. + + + +#### What's Changed +* [connect] Optimize the connector about ScaleFreeBA, ScaleFreeBADual, PowerLaw by @Routhleck in https://github.com/brainpy/BrainPy/pull/412 +* [fix] bug of `connect.base.py`'s `require` function by @Routhleck in https://github.com/brainpy/BrainPy/pull/413 +* Many Updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/414 +* Update docs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/415 +* fix conflict by @yygf123 in https://github.com/brainpy/BrainPy/pull/416 +* add a new implementation of Dual Exponential Synapse model which can be aligned post. by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/417 +* Enable test when pull requests by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/418 +* Add random.seed() by @yygf123 in https://github.com/brainpy/BrainPy/pull/419 +* Remove windows CI because it always generates strange errors by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/420 +* Recent updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/421 +* upgrade Runner and Trainer for new style of ``DynamicalSystem.update()`` function by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/422 +* update docs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/424 +* fix ``lif`` model bugs and support two kinds of spike reset: ``soft`` and ``hard`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/423 +* rewrite old synapses with decomposed components by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/425 +* fix autograd bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/426 + +#### New Contributors +* @yygf123 made their first contribution in https://github.com/brainpy/BrainPy/pull/416 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.2...V2.4.3 + + + + + + + + +### Version 2.4.2 + + + +We are very excited to release this new version of BrainPy V2.4.2. In this new update, we cover several exciting features: +#### New Features +* Reorganize the model to decouple dynamics and communication. +* Add `brainpy.dyn` for dynamics models and `brainpy.dnn` for the ANN layer and connection structures. +* Supplement many docs for dedicated operators and common bugs of BrainPy. +* Fix many bugs. + +#### What's Changed +* [ANN] add more activation functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/379 +* Optimize Gaussian Decay initializer by @Routhleck in https://github.com/brainpy/BrainPy/pull/381 +* [update] new loss functions, surrograte base class, Array built-in functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/383 +* [parallelization] new module of ``brainpy.pnn`` for auto parallelization of brain models by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/385 +* [fix] fix the bug of loading states by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/388 +* [math] support `jax.disable_jit()` for debugging by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/389 +* [initialize] speed up ``brainpy.init.DOGDecay`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/390 +* [doc] fix doc build by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/391 +* Add deprecations for deprecated APIs or functions by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/393 +* [math] enable debugging for new style of transformations in BrainPy by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/392 +* [math] flow control updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/396 +* Test of rates by @shangyangli in https://github.com/brainpy/BrainPy/pull/386 +* Add math docs: NumPy-like operations and Dedicated operators by @c-xy17 in https://github.com/brainpy/BrainPy/pull/395 +* [doc] documentation about ``how to debug`` and ``common gotchas`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/397 +* Update requirements-doc.txt by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/399 +* debug (images not displayed) by @c-xy17 in https://github.com/brainpy/BrainPy/pull/400 +* Decouple dynamics and comminucations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/401 +* [fix] bugs of control flows by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/404 +* Test for channels, neurons and synapses. by @ztqakita in https://github.com/brainpy/BrainPy/pull/403 +* Implement function to visualize connection matrix by @Routhleck in https://github.com/brainpy/BrainPy/pull/405 +* Optimize GaussianProb by @Routhleck in https://github.com/brainpy/BrainPy/pull/406 +* [dyn] add reduce models, HH-type models and channels by @ztqakita in https://github.com/brainpy/BrainPy/pull/408 +* [dnn] add various linear layers by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/407 +* [delay] `VariableDelay` and `DataDelay` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/409 +* [dyn] add COBA examples using the interface of new `brainpy.dyn` module by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/410 +* [dyn] Update dyn.neurons docs and fix several bugs by @ztqakita in https://github.com/brainpy/BrainPy/pull/411 + +#### New Contributors +* @shangyangli made their first contribution in https://github.com/brainpy/BrainPy/pull/386 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.1...V2.4.2 + + + + +### Version 2.4.1 + + +#### New Features + +1. [math] Support the error report when modifying a `brainpy.math.Array` during compilation +2. [math] add `brainpy.math.event`, `brainpy.math.sparse` and `brainpy.math.jitconn` module, needs ``brainpylib >= 0.1.9`` +3. [interoperation] add apis and docs for `brainpy.layers.FromFlax` and `brainpy.layer.ToFlaxRNNCell` +4. [fix] Bug fixes: + - fix WilsonCowan bug + - fix `brainpy.connect.FixedProb` bug + - fix analysis jit bug + + + +#### What's Changed +* Update structures by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/364 +* create blocksparse matrix matrix multiplication opearator by @Routhleck in https://github.com/brainpy/BrainPy/pull/365 +* commit by @grysgreat in https://github.com/brainpy/BrainPy/pull/367 +* Fix bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/368 +* [math] update dedicated operators by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/370 +* fix bugs by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/371 +* [bug] fix merging bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/372 +* [structure] update package structure by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/369 +* [test] update csrmv tests by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/373 +* [interoperation] add apis and docs for `brainpy.layers.FromFlax` and `brainpy.layer.ToFlaxRNNCell` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/374 +* [doc] update documentation by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/375 +* [bug] fix `brainpy.connect.FixedProb` bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/376 +* [bug] fix analysis jit bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/377 +* update brainpylib requirements by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/378 + +#### New Contributors +* @Routhleck made their first contribution in https://github.com/brainpy/BrainPy/pull/365 +* @grysgreat made their first contribution in https://github.com/brainpy/BrainPy/pull/367 + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.4.0...V2.4.1 + + + + + +### Version 2.4.0 + +This branch of releases (``brainpy==2.4.x``) are going to support the large-scale modeling for brain dynamics. + +As the start, this release provides support for automatic object-oriented (OO) transformations. + + +#### What's New + + +1. Automatic OO transformations on longer need to take ``dyn_vars`` or ``child_objs`` information. + These transformations are capable of automatically inferring the underlying dynamical variables. + Specifically, they include: + + - ``brainpy.math.grad`` and other autograd functionalities + - ``brainpy.math.jit`` + - ``brainpy.math.for_loop`` + - ``brainpy.math.while_loop`` + - ``brainpy.math.ifelse`` + - ``brainpy.math.cond`` + +2. Update documentation +3. Fix several bugs + +#### What's Changed +* reorganize operators in `brainpy.math` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/357 +* Automatic transformations any function/object using `brainpy.math.Variable` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/358 +* New OO transforms support ``jax.disable_jit`` mode by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/359 +* [oo transform] Enable new style of jit transformation to support `static_argnums` and `static_argnames` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/360 +* [documentation] update documentation to brainpy>=2.4.0 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/361 + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.3.8...V2.4.0 + + + + + + + +### Version 2.3.8 + + +This release continues to add support for improving the usability of BrainPy. + + +#### New Features + + +1. New data structures for object-oriented transformations. + - ``NodeList`` and ``NodeDict`` for a list/tuple/dict of ``BrainPyObject`` instances. + - ``ListVar`` and ``DictVar`` for a list/tuple/dict of brainpy data. +2. `Clip` transformation for brainpy initializers. +3. All ``brainpylib`` operators are accessible in ``brainpy.math`` module. Especially there are some dedicated operators for scaling up the million-level neuron networks. For an example, see example in [Simulating 1-million-neuron networks with 1GB GPU memory](https://brainpy-examples.readthedocs.io/en/latest/large_scale_modeling/EI_net_with_1m_neurons.html) +5. Enable monitoring GPU models on CPU when setting ``DSRunner(..., memory_efficient=True)``. This setting can usually reduce so much memory usage. +6. ``brainpylib`` wheels on the Linux platform support the GPU operators. Users can install GPU version of ``brainpylib`` (require ``brainpylib>=0.1.7``) directly by ``pip install brainpylib``. @ztqakita + +#### What's Changed +* Fix bugs and add more variable structures: `ListVar` and `DictVar` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/345 +* add CI for testing various models by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/346 +* Update docs and tests by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/347 +* Fix `Runner(jit=False)`` bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/348 +* Compatible with jax>=0.4.7 by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/349 +* Updates by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/350 +* reconstruct BrainPy by merging brainpylib by @ztqakita in https://github.com/brainpy/BrainPy/pull/351 +* Intergate brainpylib operators into brainpy by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/352 +* fix `brainpylib` call bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/354 +* Enable memory-efficient ``DSRunner`` by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/355 +* fix `Array` transform bug by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/356 + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.3.7...V2.3.8 + + + +### Version 2.3.7 + +- Fix bugs on population models in ``brainpy.rate`` module +- Fix bug on ``brainpy.LoopOverTime`` +- Add more synaptic models including DualExpoenetial model and Alpha model in ``brainpy.experimental`` module +- Support call a module through right shift, such as ``data >> module1 >> module2`` + + +### Version 2.3.6 + +This release continues to add support for brain-inspired computation. + + +#### New Features + +##### More flexible customization of surrogate gradient functions. + +- brainpy.math.surrogate.Sigmoid +- brainpy.math.surrogate.PiecewiseQuadratic +- brainpy.math.surrogate.PiecewiseExp +- brainpy.math.surrogate.SoftSign +- brainpy.math.surrogate.Arctan +- brainpy.math.surrogate.NonzeroSignLog +- brainpy.math.surrogate.ERF +- brainpy.math.surrogate.PiecewiseLeakyRelu +- brainpy.math.surrogate.SquarewaveFourierSeries +- brainpy.math.surrogate.S2NN +- brainpy.math.surrogate.QPseudoSpike +- brainpy.math.surrogate.LeakyRelu +- brainpy.math.surrogate.LogTailedRelu +- brainpy.math.surrogate.ReluGrad +- brainpy.math.surrogate.GaussianGrad +- brainpy.math.surrogate.InvSquareGrad +- brainpy.math.surrogate.MultiGaussianGrad +- brainpy.math.surrogate.SlayerGrad + +##### Fix bugs + +- ``brainpy.LoopOverTime`` + + + + + +### Version 2.3.5 + + + +This release continues to add support for brain-inspired computation. + + +#### New Features + + +##### 1. ``brainpy.share`` for sharing data across submodules + +In this release, we abstract the shared data as a ``brainpy.share`` object. + +This object together with ``brainpy.Delay`` we will introduce below constitutes the support that enables us to define SNN models like ANN ones. + + +##### 2. ``brainpy.Delay`` for delay processing + +``Delay`` is abstracted as a dynamical system, which can be updated/retrieved by users. + +```python +import brainpy as bp + +class EINet(bp.DynamicalSystemNS): + def __init__(self, scale=1.0, e_input=20., i_input=20., delay=None): + super().__init__() + + self.bg_exc = e_input + self.bg_inh = i_input + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + V_initializer=bp.init.Normal(-55., 2.), input_var=False) + self.E = bp.neurons.LIF(num_exc, **pars) + self.I = bp.neurons.LIF(num_inh, **pars) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.E2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.E.size), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) + ) + self.E2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.E.size, post=self.I.size, ), + g_max=we, tau=5., out=bp.experimental.COBA(E=0.) + ) + self.I2E = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.E.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + ) + self.I2I = bp.experimental.Exponential( + bp.conn.FixedProb(0.02, pre=self.I.size, post=self.I.size), + g_max=wi, tau=10., out=bp.experimental.COBA(E=-80.) + ) + self.delayE = bp.Delay(self.E.spike, entries={'E': delay}) + self.delayI = bp.Delay(self.I.spike, entries={'I': delay}) + + def update(self): + e_spike = self.delayE.at('E') + i_spike = self.delayI.at('I') + e_inp = self.E2E(e_spike, self.E.V) + self.I2E(i_spike, self.E.V) + self.bg_exc + i_inp = self.I2I(i_spike, self.I.V) + self.E2I(e_spike, self.I.V) + self.bg_inh + self.delayE(self.E(e_inp)) + self.delayI(self.I(i_inp)) + +``` + + + +##### 3. ``brainpy.checkpoints.save_pytree`` and ``brainpy.checkpoints.load_pytree`` for saving/loading target from the filename + +Now we can directly use ``brainpy.checkpoints.save_pytree`` to save a network state into the file path we specified. + +Similarly, we can use ``brainpy.checkpoints.load_pytree`` to load states from the given file path. + + +##### 4. More ANN layers + + +- brainpy.layers.ConvTranspose1d +- brainpy.layers.ConvTranspose2d +- brainpy.layers.ConvTranspose3d +- brainpy.layers.Conv1dLSTMCell +- brainpy.layers.Conv2dLSTMCell +- brainpy.layers.Conv3dLSTMCell + + +##### 5. More compatible dense operators + +PyTorch operators: + +- brainpy.math.Tensor +- brainpy.math.flatten +- brainpy.math.cat +- brainpy.math.abs +- brainpy.math.absolute +- brainpy.math.acos +- brainpy.math.arccos +- brainpy.math.acosh +- brainpy.math.arccosh +- brainpy.math.add +- brainpy.math.addcdiv +- brainpy.math.addcmul +- brainpy.math.angle +- brainpy.math.asin +- brainpy.math.arcsin +- brainpy.math.asinh +- brainpy.math.arcsin +- brainpy.math.atan +- brainpy.math.arctan +- brainpy.math.atan2 +- brainpy.math.atanh + + +TensorFlow operators: + +- brainpy.math.concat +- brainpy.math.reduce_sum +- brainpy.math.reduce_max +- brainpy.math.reduce_min +- brainpy.math.reduce_mean +- brainpy.math.reduce_all +- brainpy.math.reduce_any +- brainpy.math.reduce_logsumexp +- brainpy.math.reduce_prod +- brainpy.math.reduce_std +- brainpy.math.reduce_variance +- brainpy.math.reduce_euclidean_norm +- brainpy.math.unsorted_segment_sqrt_n +- brainpy.math.segment_mean +- brainpy.math.unsorted_segment_sum +- brainpy.math.unsorted_segment_prod +- brainpy.math.unsorted_segment_max +- brainpy.math.unsorted_segment_min +- brainpy.math.unsorted_segment_mean +- brainpy.math.segment_sum +- brainpy.math.segment_prod +- brainpy.math.segment_max +- brainpy.math.segment_min +- brainpy.math.clip_by_value +- brainpy.math.cast + + +##### Others + +- Remove the hard requirements of ``brainpylib`` and ``numba``. + + + + +### Version 2.3.4 + + +This release mainly focuses on the compatibility with other frameworks: + +1. Fix Jax import error when `jax>=0.4.2` +2. Backward compatibility of `brainpy.dyn` module +3. Start to implement and be compatible with operators in pytorch and tensorflow, so that user's pytorch/tensorflow models can be easily migrated to brainpy + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.3.3...V2.3.4 + + + + +### Version 2.3.3 + + +Improve backward compatibility: + +- monitors and inputs in ``DSRunner`` +- models in ``brainpy.dyn`` +- constants and function in ``brainpy.analysis`` + + +### Version 2.3.2 + +This release (under the branch of ``brainpy=2.3.x``) continues to add support for brain-inspired computation. + + +#### New Features + + +##### 1. New package structure for stable API release + +Unstable APIs are all hosted in ``brainpy._src`` module. +Other APIs are stable and will be maintained for a long time. + + +##### 2. New schedulers + +- `brainpy.optim.CosineAnnealingWarmRestarts` +- `brainpy.optim.CosineAnnealingLR` +- `brainpy.optim.ExponentialLR` +- `brainpy.optim.MultiStepLR` +- `brainpy.optim.StepLR` + + +##### 3. Others + +- support `static_argnums` in `brainpy.math.jit` +- fix bugs of `reset_state()` and `clear_input()` in `brainpy.channels` +- fix jit error checking + + + + + + +### Version 2.3.1 + +This release (under the release branch of ``brainpy=2.3.x``) continues to add supports for brain-inspired computation. + + + +```python +import brainpy as bp +import brainpy.math as bm +``` + + + +#### Backwards Incompatible Changes + + + +###### 1. Error: module 'brainpy' has no attribute 'datasets' + +``brainpy.datasets`` module is now published as an independent package ``brainpy_datasets``. + +Please change your dataset access from + +```python +bp.datasets.xxxxx +``` + +to + +```python +import brainpy_datasets as bp_data + +bp_data.chaos.XXX +bp_data.vision.XXX +``` + +For a chaotic data series, + +```python +# old version +data = bp.datasets.double_scroll_series(t_warmup + t_train + t_test, dt=dt) +x_var = data['x'] +y_var = data['y'] +z_var = data['z'] + +# new version +data = bd.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt) +x_var = data.xs +y_var = data.ys +z_var = data.zs +``` + +For a vision dataset, + +```python +# old version +dataset = bp.datasets.FashionMNIST(root, train=True, download=True) + +# new version +dataset = bd.vision.FashionMNIST(root, split='train', download=True) +``` + + + +###### 2. Error: DSTrainer must receive an instance with BatchingMode + +This error will happen when using ``brainpy.OnlineTrainer`` , ``brainpy.OfflineTrainer``, ``brainpy.BPTT`` , ``brainpy.BPFF``. + +From version 2.3.1, BrainPy explicitly consider the computing mode of each model. For trainers, all training target should be a model with ``BatchingMode`` or ``TrainingMode``. + +If you are training model with ``OnlineTrainer`` or ``OfflineTrainer``, + +```python +# old version +class NGRC(bp.DynamicalSystem): + def __init__(self, num_in): + super(NGRC, self).__init__() + self.r = bp.layers.NVAR(num_in, delay=2, order=3) + self.di = bp.layers.Dense(self.r.num_out, num_in) + + def update(self, sha, x): + di = self.di(sha, self.r(sha, x)) + return x + di + + +# new version +bm.set_enviroment(mode=bm.batching_mode) + +class NGRC(bp.DynamicalSystem): + def __init__(self, num_in): + super(NGRC, self).__init__() + self.r = bp.layers.NVAR(num_in, delay=2, order=3) + self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode) + + def update(self, sha, x): + di = self.di(sha, self.r(sha, x)) + return x + di +``` + + If you are training models with ``BPTrainer``, adding the following line at the top of the script, + +```python +bm.set_enviroment(mode=bm.training_mode) +``` + + + +###### 3. Error: inputs_are_batching is no longer supported. + +This is because if the training target is in ``batching`` mode, this has already indicated that the inputs should be batching. + +Simple remove the ``inputs_are_batching`` from your functional call of ``.predict()`` will solve the issue. + + + + + +#### New Features + + + +##### 1. ``brainpy.math`` module upgrade + +###### ``brainpy.math.surrogate`` module for surrogate gradient functions. + +Currently, we support + +- `brainpy.math.surrogate.arctan` +- `brainpy.math.surrogate.erf` +- `brainpy.math.surrogate.gaussian_grad` +- `brainpy.math.surrogate.inv_square_grad` +- `brainpy.math.surrogate.leaky_relu` +- `brainpy.math.surrogate.log_tailed_relu` +- `brainpy.math.surrogate.multi_gaussian_grad` +- `brainpy.math.surrogate.nonzero_sign_log` +- `brainpy.math.surrogate.one_input` +- `brainpy.math.surrogate.piecewise_exp` +- `brainpy.math.surrogate.piecewise_leaky_relu` +- `brainpy.math.surrogate.piecewise_quadratic` +- `brainpy.math.surrogate.q_pseudo_spike` +- `brainpy.math.surrogate.relu_grad` +- `brainpy.math.surrogate.s2nn` +- `brainpy.math.surrogate.sigmoid` +- `brainpy.math.surrogate.slayer_grad` +- `brainpy.math.surrogate.soft_sign` +- `brainpy.math.surrogate.squarewave_fourier_series` + + + +###### New transformation function ``brainpy.math.to_dynsys`` + +New transformation function ``brainpy.math.to_dynsys`` supports to transform a pure Python function into a ``DynamicalSystem``. This will be useful when running a `DynamicalSystem` with arbitrary customized inputs. + +```python +import brainpy.math as bm + +hh = bp.neurons.HH(1) + +@bm.to_dynsys(child_objs=hh) +def run_hh(tdi, x=None): + if x is not None: + hh.input += x + +runner = bp.DSRunner(run_hhh, monitors={'v': hh.V}) +runner.run(inputs=bm.random.uniform(3, 6, 1000)) +``` + + + +###### Default data types + +Default data types `brainpy.math.int_`, `brainpy.math.float_` and `brainpy.math.complex_` are initialized according to the default `x64` settings. Then, these data types can be set or get by `brainpy.math.set_*` or `brainpy.math.get_*` syntaxes. + +Take default integer type ``int_`` as an example, + +```python +# set the default integer type +bm.set_int_(jax.numpy.int64) + +# get the default integer type +a1 = bm.asarray([1], dtype=bm.int_) +a2 = bm.asarray([1], dtype=bm.get_int()) # equivalent +``` + +Default data types are changed according to the `x64` setting of JAX. For instance, + +```python +bm.enable_x64() +assert bm.int_ == jax.numpy.int64 +bm.disable_x64() +assert bm.int_ == jax.numpy.int32 +``` + +``brainpy.math.float_`` and ``brainpy.math.complex_`` behaves similarly with ``brainpy.math.int_``. + + + +###### Environment context manager + +This release introduces a new concept ``computing environment`` in BrainPy. Computing environment is a default setting for current computation jobs, including the default data type (``int_``, ``float_``, ``complex_``), the default numerical integration precision (``dt``), the default computing mode (``mode``). All models, arrays, and computations using the default setting will be carried out under the environment setting. + +Users can set a default environment through + +```python +brainpy.math.set_environment(mode, dt, x64) +``` + +However, ones can also construct models or perform computation through a temporal environment context manager, this can be implemented through: + +```python +# constructing a HH model with dt=0.1 and x64 precision +with bm.environment(mode, dt=0.1, x64=True): + hh1 = bp.neurons.HH(1) + +# constructing a HH model with dt=0.05 and x32 precision +with bm.environment(mode, dt=0.05, x64=False): + hh2 = bp.neuron.HH(1) +``` + +Usually, users construct models for either brain-inspired computing (``training mode``) or brain simulation (``nonbatching mode``), therefore, there are shortcut context manager for setting a training environment or batching environment: + +```python +with bm.training_environment(dt, x64): + pass + +with bm.batching_environment(dt, x64): + pass +``` + + + +##### 2. ``brainpy.dyn`` module + + + +###### ``brainpy.dyn.transfom`` module for transforming a ``DynamicalSystem`` instance to a callable ``BrainPyObject``. + +Specifically, we provide + +- `LoopOverTime` for unrolling a dynamical system over time. +- `NoSharedArg` for removing the dependency of shared arguments. + + + + + +##### 3. Running supports in BrainPy + + + +###### All ``brainpy.Runner`` now are subclasses of ``BrainPyObject`` + +This means that all ``brainpy.Runner`` can be used as a part of the high-level program or transformation. + + + +###### Enable the continuous running of a differential equation (ODE, SDE, FDE, DDE, etc.) with `IntegratorRunner`. + +For example, + +```python +import brainpy as bp + +# differential equation +a, b, tau = 0.7, 0.8, 12.5 +dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext +dw = lambda w, t, V: (V + a - b * w) / tau +fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1) + +# differential integrator runner +runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.]) + +# run 1 +Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True) +runner.run(duration, dyn_args=dict(Iext=Iext)) +bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V') + +# run 2 +Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True) +runner.run(duration, dyn_args=dict(Iext=Iext)) +bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2', show=True) + +``` + + + +###### Enable call a customized function during fitting of ``brainpy.BPTrainer``. + +This customized function (provided through ``fun_after_report``) will be useful to save a checkpoint during the training. For instance, + +```python +class CheckPoint: + def __init__(self, path='path/to/directory/'): + self.max_acc = 0. + self.path = path + + def __call__(self, idx, metrics, phase): + if phase == 'test' and metrics['acc'] > self.max_acc: + self.max_acc = matrics['acc'] + bp.checkpoints.save(self.path, net.state_dict(), idx) + +trainer = bp.BPTT() +trainer.fit(..., fun_after_report=CheckPoint()) +``` + + + +###### Enable data with ``data_first_axis`` format when predicting or fitting in a ``brainpy.DSRunner`` and ``brainpy.DSTrainer``. + +Previous version of BrainPy only supports data with the batch dimension at the first axis. Currently, ``brainpy.DSRunner`` and ``brainpy.DSTrainer`` can support the data with the time dimension at the first axis. This can be set through ``data_first_axis='T'`` when initializing a runner or trainer. + +```python +runner = bp.DSRunner(..., data_first_axis='T') +trainer = bp.DSTrainer(..., data_first_axis='T') +``` + + + +##### 4. Utility in BrainPy + + + +###### ``brainpy.encoding`` module for encoding rate values into spike trains + + Currently, we support + +- `brainpy.encoding.LatencyEncoder` +- `brainpy.encoding.PoissonEncoder` +- `brainpy.encoding.WeightedPhaseEncoder` + + + +###### ``brainpy.checkpoints`` module for model state serialization. + +This version of BrainPy supports to save a checkpoint of the model into the physical disk. Inspired from the Flax API, we provide the following checkpoint APIs: + +- ``brainpy.checkpoints.save()`` for saving a checkpoint of the model. +- ``brainpy.checkpoints.multiprocess_save()`` for saving a checkpoint of the model in multi-process environment. +- ``brainpy.checkpoints.load()`` for loading the last or best checkpoint from the given checkpoint path. +- ``brainpy.checkpoints.load_latest()`` for retrieval the path of the latest checkpoint in a directory. + + + + + +#### Deprecations + + + +##### 1. Deprecations in the running supports of BrainPy + +###### ``func_monitors`` is no longer supported in all ``brainpy.Runner`` subclasses. + +We will remove its supports since version 2.4.0. Instead, monitoring with a dict of callable functions can be set in ``monitors``. For example, + + + ```python + # old version + + runner = bp.DSRunner(model, + monitors={'sps': model.spike, 'vs': model.V}, + func_monitors={'sp10': model.spike[10]}) + ``` + + ```python + # new version + runner = bp.DSRunner(model, + monitors={'sps': model.spike, + 'vs': model.V, + 'sp10': model.spike[10]}) + ``` + + + +###### ``func_inputs`` is no longer supported in all ``brainpy.Runner`` subclasses. + + Instead, giving inputs with a callable function should be done with ``inputs``. + +```python +# old version + +net = EINet() + +def f_input(tdi): + net.E.input += 10. + +runner = bp.DSRunner(net, fun_inputs=f_input, inputs=('I.input', 10.)) +``` + +```python +# new version + +def f_input(tdi): + net.E.input += 10. + net.I.input += 10. +runner = bp.DSRunner(net, inputs=f_input) +``` + + + +###### ``inputs_are_batching`` is deprecated. + +``inputs_are_batching`` is deprecated in ``predict()``/``.run()`` of all ``brainpy.Runner`` subclasses. + + + +###### ``args`` and ``dyn_args`` are now deprecated in ``IntegratorRunner``. + +Instead, users should specify ``args`` and ``dyn_args`` when using ``IntegratorRunner.run()`` function. + +```python +dV = lambda V, t, w, I: V - V * V * V / 3 - w + I +dw = lambda w, t, V, a, b: (V + a - b * w) / 12.5 +integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto') + +# old version +runner = bp.IntegratorRunner( + integral, + monitors=['V', 'w'], + inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)}, + args={'a': 1., 'b': 1.}, # CHANGE + dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # CHANGE +) +runner.run(100.,) + +``` + +```python +# new version +runner = bp.IntegratorRunner( + integral, + monitors=['V', 'w'], + inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)}, +) +runner.run(100., + args={'a': 1., 'b': 1.}, + dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}) +``` + + + +##### 2. Deprecations in ``brainpy.math`` module + +###### `ditype()` and `dftype()` are deprecated. + +`brainpy.math.ditype()` and `brainpy.math.dftype()` are deprecated. Using `brainpy.math.int_` and `brainpy.math.float()` instead. + + + +###### ``brainpy.modes`` module is now moved into ``brainpy.math`` + +The correspondences are listed as the follows: + +- ``brainpy.modes.Mode`` => ``brainpy.math.Mode`` +- ``brainpy.modes.NormalMode `` => ``brainpy.math.NonBatchingMode`` +- ``brainpy.modes.BatchingMode `` => ``brainpy.math.BatchingMode`` +- ``brainpy.modes.TrainingMode `` => ``brainpy.math.TrainingMode`` +- ``brainpy.modes.normal `` => ``brainpy.math.nonbatching_mode`` +- ``brainpy.modes.batching `` => ``brainpy.math.batching_mode`` +- ``brainpy.modes.training `` => ``brainpy.math.training_mode`` + + + + + + +### Version 2.3.0 + +This branch of releases aims to provide a unified computing framework for brain simulation and brain-inspired computing. + +#### New features + +1. ``brainpy.BPTT`` supports `train_data` and `test_data` with general Python iterators. For instance, one can train a model with PyTorch dataloader or TensorFlow datasets. + +```python +import torchvision +from torch.utils.data import DataLoader +data = torchvision.datasets.CIFAR10("./CIFAR10", train=False, transform=torchvision.transforms.ToTensor()) +loader = DataLoader(dataset=data, batch_size=4, shuffle=True, num_workers=0, drop_last=False) + +# any generator can be used for train_data or test_data +trainer = bp.BPTT() +trainer.fit(loader) +``` + +2. Consolidated object-oriented transformation in ``brainpy.math.object_transform`` module. All brainpy transformations generate a new ``BrainPyObject`` instance so that objects in brainpy can be composed hierarchically. ``brainpy.math.to_object()`` transformation transforms a pure Python function into a ``BrainPyObject``. + +3. New [documentation](https://brainpy.readthedocs.io/en/latest/tutorial_math/brainpy_transform_concept.html) is currently online for introducing the consolidated BrainPy concept of object-oriented transformation. + +4. Change ``brainpy.math.JaxArray`` to ``brainpy.math.Array``. + + + + +#### Deprecations + +1. ``brainpy.datasets`` module is no longer supported. New APIs will be moved into [``brainpy-datasets`` package](https://github.com/brainpy/datasets). +2. ``brainpy.train.BPTT`` no longer support to receive the train data `[X, Y]`. Instead, users should provide a data generator such like ``pytorch`` dataset or ``tensorflow`` dataset. +4. The update function of ``brainpy.math.TimeDealy`` does not support receiving a `time` index. Instead, one can update the new data by directly using ``TimeDealy.update(data)`` instead of `TimeDealy.update(time, data)`. +5. Fix the monitoring error of delay differential equations with ``brainpy.integrators.IntegratorRunner``. + +#### Bug Fixes + +1. Fix the bug on ``One2One`` connection. +2. Fix the bug in ``eprop`` example. +3. Fix `ij2csr` transformation error. +4. Fix test bugs + +#### What's Changed +* fix eprop example error by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/305 +* minor updates on API and DOC by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/306 +* Add new optimizers by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/307 +* add documentation of for random number generation by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/308 +* consolidate the concept of OO transformation by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/309 +* Upgrade documetations by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/310 +* Ready for publish by @chaoming0625 in https://github.com/brainpy/BrainPy/pull/311 + + +**Full Changelog**: https://github.com/brainpy/BrainPy/compare/V2.2.4.0...V2.3.0 + + +## brainpy 2.2.x + +BrainPy 2.2.x is a complete re-design of the framework, tackling the +shortcomings of brainpy 2.1.x generation, effectively bringing it to +research needs and standards. + + + + +### Version 2.2.4 + +This release has updated many functionalities and fixed several bugs in BrainPy. + +#### New Features + +1. More ANN layers, including ``brainpy.layers.Flatten`` and ``brainpy.layers.Activation``. +2. Optimized connection building for ``brainpy.connect`` module. +3. cifar dataset. +4. Enhanced API and Doc for parallel simulations via ``brainpy.running.cpu_ordered_parallel``, ``brainpy.running.cpu_unordered_parallel``, ``brainpy.running.jax_vectorize_map`` and ``brainpy.running.jax_parallelize_map``. + + +#### What's Changed +* add Activation and Flatten class by @LuckyHFC in https://github.com/PKU-NIP-Lab/BrainPy/pull/291 +* optimizes the connect time when using gpu by @MamieZhu in https://github.com/PKU-NIP-Lab/BrainPy/pull/293 +* datasets::vision: add cifar dataset by @hbelove in https://github.com/PKU-NIP-Lab/BrainPy/pull/292 +* fix #294: remove VariableView in `dyn_vars` of a runner by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/295 +* update issue template by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/296 +* add multiprocessing functions for batch running of BrainPy functions by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/298 +* upgrade connection apis by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/299 +* fix #300: update parallelization api documentation by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/302 +* update doc by @chaoming0625 in https://github.com/PKU-NIP-Lab/BrainPy/pull/303 + +#### New Contributors +* @LuckyHFC made their first contribution in https://github.com/PKU-NIP-Lab/BrainPy/pull/291 +* @MamieZhu made their first contribution in https://github.com/PKU-NIP-Lab/BrainPy/pull/293 +* @hbelove made their first contribution in https://github.com/PKU-NIP-Lab/BrainPy/pull/292 + +**Full Changelog**: https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.2.3.6...V2.2.4 + + + + +### Version 2.2.1 (2022.09.09) + +This release fixes bugs found in the codebase and improves the usability +and functions of BrainPy. + +#### Bug fixes + +1. Fix the bug of operator customization in `brainpy.math.XLACustomOp` + and `brainpy.math.register_op`. Now, it supports operator + customization by using NumPy and Numba interface. For instance, + +``` python +import brainpy.math as bm + +def abs_eval(events, indices, indptr, post_val, values): + return post_val + +def con_compute(outs, ins): + post_val = outs + events, indices, indptr, _, values = ins + for i in range(events.size): + if events[i]: + for j in range(indptr[i], indptr[i + 1]): + index = indices[j] + old_value = post_val[index] + post_val[index] = values + old_value + +event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) +``` + +1. Fix the bug of `brainpy.tools.DotDict`. Now, it is compatible with + the transformations of JAX. For instance, + +``` python +import brainpy as bp +from jax import vmap + +@vmap +def multiple_run(I): + hh = bp.neurons.HH(1) + runner = bp.dyn.DSRunner(hh, inputs=('input', I), numpy_mon_after_run=False) + runner.run(100.) + return runner.mon + +mon = multiple_run(bp.math.arange(2, 10, 2)) +``` + +#### New features + +1. Add numpy operators `brainpy.math.mat`, `brainpy.math.matrix`, + `brainpy.math.asmatrix`. +2. Improve translation rules of brainpylib operators, improve its + running speeds. +3. Support `DSView` of `DynamicalSystem` instance. Now, it supports + defining models with a slice view of a DS instance. For example, + +``` python +import brainpy as bp +import brainpy.math as bm + + +class EINet_V2(bp.dyn.Network): + def __init__(self, scale=1.0, method='exp_auto'): + super(EINet_V2, self).__init__() + + # network size + num_exc = int(3200 * scale) + num_inh = int(800 * scale) + + # neurons + self.N = bp.neurons.LIF(num_exc + num_inh, + V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., + method=method, V_initializer=bp.initialize.Normal(-55., 2.)) + + # synapses + we = 0.6 / scale # excitatory synaptic weight (voltage) + wi = 6.7 / scale # inhibitory synaptic weight + self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc], post=self.N, + conn=bp.connect.FixedProb(0.02), + g_max=we, tau=5., + output=bp.synouts.COBA(E=0.), + method=method) + self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:], post=self.N, + conn=bp.connect.FixedProb(0.02), + g_max=wi, tau=10., + output=bp.synouts.COBA(E=-80.), + method=method) + +net = EINet_V2(scale=1., method='exp_auto') +# simulation +runner = bp.dyn.DSRunner( + net, + monitors={'spikes': net.N.spike}, + inputs=[(net.N.input, 20.)] + ) +runner.run(100.) + +# visualization +bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True) +``` + +### Version 2.2.0 (2022.08.12) + +This release has provided important improvements for BrainPy, including +usability, speed, functions, and others. + +#### Backwards Incompatible changes + +1. `brainpy.nn` module is no longer supported and has been removed + since version 2.2.0. Instead, users should use `brainpy.train` + module for the training of BP algorithms, online learning, or + offline learning algorithms, and `brainpy.algorithms` module for + online / offline training algorithms. +2. The `update()` function for the model definition has been changed: + +``` +>>> # 2.1.x +>>> +>>> import brainpy as bp +>>> +>>> class SomeModel(bp.dyn.DynamicalSystem): +>>> def __init__(self, ): +>>> ...... +>>> def update(self, t, dt): +>>> pass +>>> # 2.2.x +>>> +>>> import brainpy as bp +>>> +>>> class SomeModel(bp.dyn.DynamicalSystem): +>>> def __init__(self, ): +>>> ...... +>>> def update(self, tdi): +>>> t, dt = tdi.t, tdi.dt +>>> pass +``` + +where `tdi` can be defined with other names, like `sha`, to represent +the shared argument across modules. + +#### Deprecations + +1. `brainpy.dyn.xxx (neurons)` and `brainpy.dyn.xxx (synapse)` are no + longer supported. Please use `brainpy.neurons`, `brainpy.synapses` + modules. +2. `brainpy.running.monitor` has been removed. +3. `brainpy.nn` module has been removed. + +#### New features + +1. `brainpy.math.Variable` receives a `batch_axis` setting to represent + the batch axis of the data. + +``` +>>> import brainpy.math as bm +>>> a = bm.Variable(bm.zeros((1, 4, 5)), batch_axis=0) +>>> a.value = bm.zeros((2, 4, 5)) # success +>>> a.value = bm.zeros((1, 2, 5)) # failed +MathError: The shape of the original data is (2, 4, 5), while we got (1, 2, 5) with batch_axis=0. +``` + +2. `brainpy.train` provides `brainpy.train.BPTT` for back-propagation + algorithms, `brainpy.train.Onlinetrainer` for online training + algorithms, `brainpy.train.OfflineTrainer` for offline training + algorithms. +3. `brainpy.Base` class supports `_excluded_vars` setting to ignore + variables when retrieving variables by using `Base.vars()` method. + +``` +>>> class OurModel(bp.Base): +>>> _excluded_vars = ('a', 'b') +>>> def __init__(self): +>>> super(OurModel, self).__init__() +>>> self.a = bm.Variable(bm.zeros(10)) +>>> self.b = bm.Variable(bm.ones(20)) +>>> self.c = bm.Variable(bm.random.random(10)) +>>> +>>> model = OurModel() +>>> model.vars().keys() +dict_keys(['OurModel0.c']) +``` + +4. `brainpy.analysis.SlowPointFinder` supports directly analyzing an + instance of `brainpy.dyn.DynamicalSystem`. + +``` +>>> hh = bp.neurons.HH(1) +>>> finder = bp.analysis.SlowPointFinder(hh, target_vars={'V': hh.V, 'm': hh.m, 'h': hh.h, 'n': hh.n}) +``` + +5. `brainpy.datasets` supports MNIST, FashionMNIST, and other datasets. +6. Supports defining conductance-based neuron models\`\`. + +``` +>>> class HH(bp.dyn.CondNeuGroup): +>>> def __init__(self, size): +>>> super(HH, self).__init__(size) +>>> +>>> self.INa = channels.INa_HH1952(size, ) +>>> self.IK = channels.IK_HH1952(size, ) +>>> self.IL = channels.IL(size, E=-54.387, g_max=0.03) +``` + +7. `brainpy.layers` module provides commonly used models for DNN and + reservoir computing. +8. Support composable definition of synaptic models by using + `TwoEndConn`, `SynOut`, `SynSTP` and `SynLTP`. + +``` +>>> bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob), +>>> g_max=0.03 / scale, tau=5, +>>> output=bp.synouts.COBA(E=0.), +>>> stp=bp.synplast.STD()) +``` + +9. Provide commonly used surrogate gradient function for spiking + generation, including + - `brainpy.math.spike_with_sigmoid_grad` + - `brainpy.math.spike_with_linear_grad` + - `brainpy.math.spike_with_gaussian_grad` + - `brainpy.math.spike_with_mg_grad` +10. Provide shortcuts for GPU memory management via + `brainpy.math.disable_gpu_memory_preallocation()`, and + `brainpy.math.clear_buffer_memory()`. + +#### What\'s Changed + +- fix [#207](https://github.com/PKU-NIP-Lab/BrainPy/issues/207): + synapses update first, then neurons, finally delay variables by + [\@chaoming0625](https://github.com/chaoming0625) in + [#219](https://github.com/PKU-NIP-Lab/BrainPy/pull/219) +- docs: add logos by [\@ztqakita](https://github.com/ztqakita) in + [#218](https://github.com/PKU-NIP-Lab/BrainPy/pull/218) +- Add the biological NMDA model by + [\@c-xy17](https://github.com/c-xy17) in + [#221](https://github.com/PKU-NIP-Lab/BrainPy/pull/221) +- docs: fix mathjax problem by + [\@ztqakita](https://github.com/ztqakita) in + [#222](https://github.com/PKU-NIP-Lab/BrainPy/pull/222) +- Add the parameter R to the LIF model by + [\@c-xy17](https://github.com/c-xy17) in + [#224](https://github.com/PKU-NIP-Lab/BrainPy/pull/224) +- new version of brainpy: V2.2.0-rc1 by + [\@chaoming0625](https://github.com/chaoming0625) in + [#226](https://github.com/PKU-NIP-Lab/BrainPy/pull/226) +- update training apis by + [\@chaoming0625](https://github.com/chaoming0625) in + [#227](https://github.com/PKU-NIP-Lab/BrainPy/pull/227) +- Update quickstart and the analysis module by + [\@c-xy17](https://github.com/c-xy17) in + [#229](https://github.com/PKU-NIP-Lab/BrainPy/pull/229) +- Eseential updates for montors, analysis, losses, and examples by + [\@chaoming0625](https://github.com/chaoming0625) in + [#230](https://github.com/PKU-NIP-Lab/BrainPy/pull/230) +- add numpy op tests by [\@ztqakita](https://github.com/ztqakita) in + [#231](https://github.com/PKU-NIP-Lab/BrainPy/pull/231) +- Integrated simulation, simulaton and analysis by + [\@chaoming0625](https://github.com/chaoming0625) in + [#232](https://github.com/PKU-NIP-Lab/BrainPy/pull/232) +- update docs by [\@chaoming0625](https://github.com/chaoming0625) in + [#233](https://github.com/PKU-NIP-Lab/BrainPy/pull/233) +- unify `brainpy.layers` with other modules in `brainpy.dyn` by + [\@chaoming0625](https://github.com/chaoming0625) in + [#234](https://github.com/PKU-NIP-Lab/BrainPy/pull/234) +- fix bugs by [\@chaoming0625](https://github.com/chaoming0625) in + [#235](https://github.com/PKU-NIP-Lab/BrainPy/pull/235) +- update apis, docs, examples and others by + [\@chaoming0625](https://github.com/chaoming0625) in + [#236](https://github.com/PKU-NIP-Lab/BrainPy/pull/236) +- fixes by [\@chaoming0625](https://github.com/chaoming0625) in + [#237](https://github.com/PKU-NIP-Lab/BrainPy/pull/237) +- fix: add dtype promotion = standard by + [\@ztqakita](https://github.com/ztqakita) in + [#239](https://github.com/PKU-NIP-Lab/BrainPy/pull/239) +- updates by [\@chaoming0625](https://github.com/chaoming0625) in + [#240](https://github.com/PKU-NIP-Lab/BrainPy/pull/240) +- update training docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#241](https://github.com/PKU-NIP-Lab/BrainPy/pull/241) +- change doc path/organization by + [\@chaoming0625](https://github.com/chaoming0625) in + [#242](https://github.com/PKU-NIP-Lab/BrainPy/pull/242) +- Update advanced docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#243](https://github.com/PKU-NIP-Lab/BrainPy/pull/243) +- update quickstart docs & enable jit error checking by + [\@chaoming0625](https://github.com/chaoming0625) in + [#244](https://github.com/PKU-NIP-Lab/BrainPy/pull/244) +- update apis and examples by + [\@chaoming0625](https://github.com/chaoming0625) in + [#245](https://github.com/PKU-NIP-Lab/BrainPy/pull/245) +- update apis and tests by + [\@chaoming0625](https://github.com/chaoming0625) in + [#246](https://github.com/PKU-NIP-Lab/BrainPy/pull/246) +- Docs update and bugs fixed by + [\@ztqakita](https://github.com/ztqakita) in + [#247](https://github.com/PKU-NIP-Lab/BrainPy/pull/247) +- version 2.2.0 by [\@chaoming0625](https://github.com/chaoming0625) + in [#248](https://github.com/PKU-NIP-Lab/BrainPy/pull/248) +- add norm and pooling & fix bugs in operators by + [\@ztqakita](https://github.com/ztqakita) in + [#249](https://github.com/PKU-NIP-Lab/BrainPy/pull/249) + +**Full Changelog**: +[V2.1.12\...V2.2.0](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.12...V2.2.0) + +## brainpy 2.1.x + +### Version 2.1.12 (2022.05.17) + +#### Highlights + +This release is excellent. We have made important improvements. + +1. We provide dozens of random sampling in NumPy which are not + supportted in JAX, such as `brainpy.math.random.bernoulli`, + `brainpy.math.random.lognormal`, `brainpy.math.random.binomial`, + `brainpy.math.random.chisquare`, `brainpy.math.random.dirichlet`, + `brainpy.math.random.geometric`, `brainpy.math.random.f`, + `brainpy.math.random.hypergeometric`, + `brainpy.math.random.logseries`, `brainpy.math.random.multinomial`, + `brainpy.math.random.multivariate_normal`, + `brainpy.math.random.negative_binomial`, + `brainpy.math.random.noncentral_chisquare`, + `brainpy.math.random.noncentral_f`, `brainpy.math.random.power`, + `brainpy.math.random.rayleigh`, `brainpy.math.random.triangular`, + `brainpy.math.random.vonmises`, `brainpy.math.random.wald`, + `brainpy.math.random.weibull` +2. make efficient checking on numerical values. Instead of direct + `id_tap()` checking which has large overhead, currently + `brainpy.tools.check_erro_in_jit()` is highly efficient. +3. Fix `JaxArray` operator errors on `None` +4. improve oo-to-function transformation speeds +5. `io` works: `.save_states()` and `.load_states()` + +#### What's Changed + +- support dtype setting in array interchange functions by + \[@chaoming0625\]() in + [#209](https://github.com/PKU-NIP-Lab/BrainPy/pull/209) +- fix [#144](https://github.com/PKU-NIP-Lab/BrainPy/issues/144): + operations on None raise errors by + \[@chaoming0625\]() in + [#210](https://github.com/PKU-NIP-Lab/BrainPy/pull/210) +- add tests and new functions for random sampling by + \[@c-xy17\]() in + [#213](https://github.com/PKU-NIP-Lab/BrainPy/pull/213) +- feat: fix `io` for brainpy.Base by + \[@chaoming0625\]() in + [#211](https://github.com/PKU-NIP-Lab/BrainPy/pull/211) +- update advanced tutorial documentation by + \[@chaoming0625\]() in + [#212](https://github.com/PKU-NIP-Lab/BrainPy/pull/212) +- fix [#149](https://github.com/PKU-NIP-Lab/BrainPy/issues/149) + (dozens of random samplings in NumPy) and fix JaxArray op errors by + \[@chaoming0625\]() in + [#216](https://github.com/PKU-NIP-Lab/BrainPy/pull/216) +- feat: efficient checking on numerical values by + \[@chaoming0625\]() in + [#217](https://github.com/PKU-NIP-Lab/BrainPy/pull/217) + +**Full Changelog**: +[V2.1.11\...V2.1.12](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.11...V2.1.12) + +### Version 2.1.11 (2022.05.15) + +#### What\'s Changed + +- fix: cross-correlation bug by + [\@ztqakita](https://github.com/ztqakita) in + [#201](https://github.com/PKU-NIP-Lab/BrainPy/pull/201) +- update apis, test and docs of numpy ops by + [\@chaoming0625](https://github.com/chaoming0625) in + [#202](https://github.com/PKU-NIP-Lab/BrainPy/pull/202) +- docs: add sphinx_book_theme by + [\@ztqakita](https://github.com/ztqakita) in + [#203](https://github.com/PKU-NIP-Lab/BrainPy/pull/203) +- fix: add requirements-doc.txt by + [\@ztqakita](https://github.com/ztqakita) in + [#204](https://github.com/PKU-NIP-Lab/BrainPy/pull/204) +- update control flow, integrators, operators, and docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#205](https://github.com/PKU-NIP-Lab/BrainPy/pull/205) +- improve oo-to-function transformation speed by + [\@chaoming0625](https://github.com/chaoming0625) in + [#208](https://github.com/PKU-NIP-Lab/BrainPy/pull/208) + +**Full Changelog**: +[V2.1.10\...V2.1.11](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.10...V2.1.11) + +### Version 2.1.10 (2022.05.05) + +#### What\'s Changed + +- update control flow APIs and Docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#192](https://github.com/PKU-NIP-Lab/BrainPy/pull/192) +- doc: update docs of dynamics simulation by + [\@chaoming0625](https://github.com/chaoming0625) in + [#193](https://github.com/PKU-NIP-Lab/BrainPy/pull/193) +- fix [#125](https://github.com/PKU-NIP-Lab/BrainPy/issues/125): add + channel models and two-compartment Pinsky-Rinzel model by + [\@chaoming0625](https://github.com/chaoming0625) in + [#194](https://github.com/PKU-NIP-Lab/BrainPy/pull/194) +- JIT errors do not change Variable values by + [\@chaoming0625](https://github.com/chaoming0625) in + [#195](https://github.com/PKU-NIP-Lab/BrainPy/pull/195) +- fix a bug in math.activations.py by + [\@c-xy17](https://github.com/c-xy17) in + [#196](https://github.com/PKU-NIP-Lab/BrainPy/pull/196) +- Functionalinaty improvements by + [\@chaoming0625](https://github.com/chaoming0625) in + [#197](https://github.com/PKU-NIP-Lab/BrainPy/pull/197) +- update rate docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#198](https://github.com/PKU-NIP-Lab/BrainPy/pull/198) +- update brainpy.dyn doc by + [\@chaoming0625](https://github.com/chaoming0625) in + [#199](https://github.com/PKU-NIP-Lab/BrainPy/pull/199) + +**Full Changelog**: +[V2.1.8\...V2.1.10](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.8...V2.1.10) + +### Version 2.1.8 (2022.04.26) + +#### What\'s Changed + +- Fix [#120](https://github.com/PKU-NIP-Lab/BrainPy/issues/120) by + [\@chaoming0625](https://github.com/chaoming0625) in + [#178](https://github.com/PKU-NIP-Lab/BrainPy/pull/178) +- feat: brainpy.Collector supports addition and subtraction by + [\@chaoming0625](https://github.com/chaoming0625) in + [#179](https://github.com/PKU-NIP-Lab/BrainPy/pull/179) +- feat: delay variables support \"indices\" and \"reset()\" function + by [\@chaoming0625](https://github.com/chaoming0625) in + [#180](https://github.com/PKU-NIP-Lab/BrainPy/pull/180) +- Support reset functions in neuron and synapse models by + [\@chaoming0625](https://github.com/chaoming0625) in + [#181](https://github.com/PKU-NIP-Lab/BrainPy/pull/181) +- `update()` function on longer need `_t` and `_dt` by + [\@chaoming0625](https://github.com/chaoming0625) in + [#183](https://github.com/PKU-NIP-Lab/BrainPy/pull/183) +- small updates by [\@chaoming0625](https://github.com/chaoming0625) + in [#188](https://github.com/PKU-NIP-Lab/BrainPy/pull/188) +- feat: easier control flows with `brainpy.math.ifelse` by + [\@chaoming0625](https://github.com/chaoming0625) in + [#189](https://github.com/PKU-NIP-Lab/BrainPy/pull/189) +- feat: update delay couplings of `DiffusiveCoupling` and + `AdditiveCouping` by + [\@chaoming0625](https://github.com/chaoming0625) in + [#190](https://github.com/PKU-NIP-Lab/BrainPy/pull/190) +- update version and changelog by + [\@chaoming0625](https://github.com/chaoming0625) in + [#191](https://github.com/PKU-NIP-Lab/BrainPy/pull/191) + +**Full Changelog**: +[V2.1.7\...V2.1.8](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.7...V2.1.8) + +### Version 2.1.7 (2022.04.22) + +#### What\'s Changed + +- synapse models support heterogeneuos weights by + [\@chaoming0625](https://github.com/chaoming0625) in + [#170](https://github.com/PKU-NIP-Lab/BrainPy/pull/170) +- more efficient synapse implementation by + [\@chaoming0625](https://github.com/chaoming0625) in + [#171](https://github.com/PKU-NIP-Lab/BrainPy/pull/171) +- fix input models in brainpy.dyn by + [\@chaoming0625](https://github.com/chaoming0625) in + [#172](https://github.com/PKU-NIP-Lab/BrainPy/pull/172) +- fix: np array astype by [\@ztqakita](https://github.com/ztqakita) in + [#173](https://github.com/PKU-NIP-Lab/BrainPy/pull/173) +- update README: \'brain-py\' to \'brainpy\' by + [\@chaoming0625](https://github.com/chaoming0625) in + [#174](https://github.com/PKU-NIP-Lab/BrainPy/pull/174) +- fix: fix the updating rules in the STP model by + [\@c-xy17](https://github.com/c-xy17) in + [#176](https://github.com/PKU-NIP-Lab/BrainPy/pull/176) +- Updates and fixes by + [\@chaoming0625](https://github.com/chaoming0625) in + [#177](https://github.com/PKU-NIP-Lab/BrainPy/pull/177) + +**Full Changelog**: +[V2.1.5\...V2.1.7](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.5...V2.1.7) + +### Version 2.1.5 (2022.04.18) + +#### What\'s Changed + +- `brainpy.math.random.shuffle` is numpy like by + [\@chaoming0625](https://github.com/chaoming0625) in + [#153](https://github.com/PKU-NIP-Lab/BrainPy/pull/153) +- update LICENSE by [\@chaoming0625](https://github.com/chaoming0625) + in [#155](https://github.com/PKU-NIP-Lab/BrainPy/pull/155) +- docs: add m1 warning by [\@ztqakita](https://github.com/ztqakita) in + [#154](https://github.com/PKU-NIP-Lab/BrainPy/pull/154) +- compatible apis of \'brainpy.math\' with those of \'jax.numpy\' in + most modules by [\@chaoming0625](https://github.com/chaoming0625) in + [#156](https://github.com/PKU-NIP-Lab/BrainPy/pull/156) +- Important updates by + [\@chaoming0625](https://github.com/chaoming0625) in + [#157](https://github.com/PKU-NIP-Lab/BrainPy/pull/157) +- Updates by [\@chaoming0625](https://github.com/chaoming0625) in + [#159](https://github.com/PKU-NIP-Lab/BrainPy/pull/159) +- Add LayerNorm, GroupNorm, and InstanceNorm as nn_nodes in + normalization.py by [\@c-xy17](https://github.com/c-xy17) in + [#162](https://github.com/PKU-NIP-Lab/BrainPy/pull/162) +- feat: add conv & pooling nodes by + [\@ztqakita](https://github.com/ztqakita) in + [#161](https://github.com/PKU-NIP-Lab/BrainPy/pull/161) +- fix: update setup.py by [\@ztqakita](https://github.com/ztqakita) in + [#163](https://github.com/PKU-NIP-Lab/BrainPy/pull/163) +- update setup.py by [\@chaoming0625](https://github.com/chaoming0625) + in [#165](https://github.com/PKU-NIP-Lab/BrainPy/pull/165) +- fix: change trigger condition by + [\@ztqakita](https://github.com/ztqakita) in + [#166](https://github.com/PKU-NIP-Lab/BrainPy/pull/166) +- fix: add build_conn() function by + [\@ztqakita](https://github.com/ztqakita) in + [#164](https://github.com/PKU-NIP-Lab/BrainPy/pull/164) +- update synapses by [\@chaoming0625](https://github.com/chaoming0625) + in [#167](https://github.com/PKU-NIP-Lab/BrainPy/pull/167) +- get the deserved name: brainpy by + [\@chaoming0625](https://github.com/chaoming0625) in + [#168](https://github.com/PKU-NIP-Lab/BrainPy/pull/168) +- update tests by [\@chaoming0625](https://github.com/chaoming0625) in + [#169](https://github.com/PKU-NIP-Lab/BrainPy/pull/169) + +**Full Changelog**: +[V2.1.4\...V2.1.5](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.4...V2.1.5) + +### Version 2.1.4 (2022.04.04) + +#### What\'s Changed + +- fix doc parsing bug by + [\@chaoming0625](https://github.com/chaoming0625) in + [#127](https://github.com/PKU-NIP-Lab/BrainPy/pull/127) +- Update overview_of_dynamic_model.ipynb by + [\@c-xy17](https://github.com/c-xy17) in + [#129](https://github.com/PKU-NIP-Lab/BrainPy/pull/129) +- Reorganization of `brainpylib.custom_op` and adding interface in + `brainpy.math` by [\@ztqakita](https://github.com/ztqakita) in + [#128](https://github.com/PKU-NIP-Lab/BrainPy/pull/128) +- Fix: modify `register_op` and brainpy.math interface by + [\@ztqakita](https://github.com/ztqakita) in + [#130](https://github.com/PKU-NIP-Lab/BrainPy/pull/130) +- new features about RNN training and delay differential equations by + [\@chaoming0625](https://github.com/chaoming0625) in + [#132](https://github.com/PKU-NIP-Lab/BrainPy/pull/132) +- Fix [#123](https://github.com/PKU-NIP-Lab/BrainPy/issues/123): Add + low-level operators docs and modify register_op by + [\@ztqakita](https://github.com/ztqakita) in + [#134](https://github.com/PKU-NIP-Lab/BrainPy/pull/134) +- feat: add generate_changelog by + [\@ztqakita](https://github.com/ztqakita) in + [#135](https://github.com/PKU-NIP-Lab/BrainPy/pull/135) +- fix [#133](https://github.com/PKU-NIP-Lab/BrainPy/issues/133), + support batch size training with offline algorithms by + [\@chaoming0625](https://github.com/chaoming0625) in + [#136](https://github.com/PKU-NIP-Lab/BrainPy/pull/136) +- fix [#84](https://github.com/PKU-NIP-Lab/BrainPy/issues/84): support + online training algorithms by + [\@chaoming0625](https://github.com/chaoming0625) in + [#137](https://github.com/PKU-NIP-Lab/BrainPy/pull/137) +- feat: add the batch normalization node by + [\@c-xy17](https://github.com/c-xy17) in + [#138](https://github.com/PKU-NIP-Lab/BrainPy/pull/138) +- fix: fix shape checking error by + [\@chaoming0625](https://github.com/chaoming0625) in + [#139](https://github.com/PKU-NIP-Lab/BrainPy/pull/139) +- solve [#131](https://github.com/PKU-NIP-Lab/BrainPy/issues/131), + support efficient synaptic computation for special connection types + by [\@chaoming0625](https://github.com/chaoming0625) in + [#140](https://github.com/PKU-NIP-Lab/BrainPy/pull/140) +- feat: update the API and test for batch normalization by + [\@c-xy17](https://github.com/c-xy17) in + [#142](https://github.com/PKU-NIP-Lab/BrainPy/pull/142) +- Node is default trainable by + [\@chaoming0625](https://github.com/chaoming0625) in + [#143](https://github.com/PKU-NIP-Lab/BrainPy/pull/143) +- Updates training apis and docs by + [\@chaoming0625](https://github.com/chaoming0625) in + [#145](https://github.com/PKU-NIP-Lab/BrainPy/pull/145) +- fix: add dependencies and update version by + [\@ztqakita](https://github.com/ztqakita) in + [#147](https://github.com/PKU-NIP-Lab/BrainPy/pull/147) +- update requirements by + [\@chaoming0625](https://github.com/chaoming0625) in + [#146](https://github.com/PKU-NIP-Lab/BrainPy/pull/146) +- data pass of the Node is default SingleData by + [\@chaoming0625](https://github.com/chaoming0625) in + [#148](https://github.com/PKU-NIP-Lab/BrainPy/pull/148) + +**Full Changelog**: +[V2.1.3\...V2.1.4](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.3...V2.1.4) + +### Version 2.1.3 (2022.03.27) + +This release improves the functionality and usability of BrainPy. Core +changes include + +- support customization of low-level operators by using Numba +- fix bugs + +#### What\'s Changed + +- Provide custom operators written in numba for jax jit by + [\@ztqakita](https://github.com/ztqakita) in + [#122](https://github.com/PKU-NIP-Lab/BrainPy/pull/122) +- fix DOGDecay bugs, add more features by + [\@chaoming0625](https://github.com/chaoming0625) in + [#124](https://github.com/PKU-NIP-Lab/BrainPy/pull/124) +- fix bugs by [\@chaoming0625](https://github.com/chaoming0625) in + [#126](https://github.com/PKU-NIP-Lab/BrainPy/pull/126) + +**Full Changelog** : +[V2.1.2\...V2.1.3](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.2...V2.1.3) + +### Version 2.1.2 (2022.03.23) + +This release improves the functionality and usability of BrainPy. Core +changes include + +- support rate-based whole-brain modeling +- add more neuron models, including rate neurons/synapses +- support Python 3.10 +- improve delays etc. APIs + +#### What\'s Changed + +- fix matplotlib dependency on \"brainpy.analysis\" module by + [\@chaoming0625](https://github.com/chaoming0625) in + [#110](https://github.com/PKU-NIP-Lab/BrainPy/pull/110) +- Sync master to brainpy-2.x branch by + [\@ztqakita](https://github.com/ztqakita) in + [#111](https://github.com/PKU-NIP-Lab/BrainPy/pull/111) +- add py3.6 test & delete multiple macos env by + [\@ztqakita](https://github.com/ztqakita) in + [#112](https://github.com/PKU-NIP-Lab/BrainPy/pull/112) +- Modify ci by [\@ztqakita](https://github.com/ztqakita) in + [#113](https://github.com/PKU-NIP-Lab/BrainPy/pull/113) +- Add py3.10 test by [\@ztqakita](https://github.com/ztqakita) in + [#115](https://github.com/PKU-NIP-Lab/BrainPy/pull/115) +- update python version by + [\@chaoming0625](https://github.com/chaoming0625) in + [#114](https://github.com/PKU-NIP-Lab/BrainPy/pull/114) +- add brainpylib mac py3.10 by + [\@ztqakita](https://github.com/ztqakita) in + [#116](https://github.com/PKU-NIP-Lab/BrainPy/pull/116) +- Enhance measure/input/brainpylib by + [\@chaoming0625](https://github.com/chaoming0625) in + [#117](https://github.com/PKU-NIP-Lab/BrainPy/pull/117) +- fix [#105](https://github.com/PKU-NIP-Lab/BrainPy/issues/105): Add + customize connections docs by + [\@ztqakita](https://github.com/ztqakita) in + [#118](https://github.com/PKU-NIP-Lab/BrainPy/pull/118) +- fix bugs by [\@chaoming0625](https://github.com/chaoming0625) in + [#119](https://github.com/PKU-NIP-Lab/BrainPy/pull/119) +- Whole brain modeling by + [\@chaoming0625](https://github.com/chaoming0625) in + [#121](https://github.com/PKU-NIP-Lab/BrainPy/pull/121) + +**Full Changelog**: +[V2.1.1\...V2.1.2](https://github.com/PKU-NIP-Lab/BrainPy/compare/V2.1.1...V2.1.2) + +### Version 2.1.1 (2022.03.18) + +This release continues to update the functionality of BrainPy. Core +changes include + +- numerical solvers for fractional differential equations +- more standard `brainpy.nn` interfaces + +#### New Features + +- + +Numerical solvers for fractional differential equations + +: - `brainpy.fde.CaputoEuler` +- `brainpy.fde.CaputoL1Schema` +- `brainpy.fde.GLShortMemory` + +- + +Fractional neuron models + +: - `brainpy.dyn.FractionalFHR` +- `brainpy.dyn.FractionalIzhikevich` + +- support `shared_kwargs` in [RNNTrainer]{.title-ref} and + [RNNRunner]{.title-ref} + +### Version 2.1.0 (2022.03.14) + +#### Highlights + +We are excited to announce the release of BrainPy 2.1.0. This release is +composed of nearly 270 commits since 2.0.2, made by [Chaoming +Wang](https://github.com/chaoming0625), [Xiaoyu +Chen](mailto:c-xy17@tsinghua.org.cn), and [Tianqiu +Zhang](mailto:tianqiuakita@gmail.com) . + +BrainPy 2.1.0 updates are focused on improving usability, functionality, +and stability of BrainPy. Highlights of version 2.1.0 include: + +- New module `brainpy.dyn` for dynamics building and simulation. It is + composed of many neuron models, synapse models, and others. +- New module `brainpy.nn` for neural network building and training. It + supports to define reservoir models, artificial neural networks, + ridge regression training, and back-propagation through time + training. +- New module `brainpy.datasets` for convenient dataset construction + and initialization. +- New module `brainpy.integrators.dde` for numerical integration of + delay differential equations. +- Add more numpy-like operators in `brainpy.math` module. +- Add automatic continuous integration on Linux, Windows, and MacOS + platforms. +- Fully update brainpy documentation. +- Fix bugs on `brainpy.analysis` and `brainpy.math.autograd` + +#### Incompatible changes + +- Remove `brainpy.math.numpy` module. +- Remove numba requirements +- Remove matplotlib requirements +- Remove [steps]{.title-ref} in `brainpy.dyn.DynamicalSystem` +- Remove travis CI + +#### New Features + +- `brainpy.ddeint` for numerical integration of delay differential + equations, the supported methods include: - Euler - MidPoint - + Heun2 - Ralston2 - RK2 - RK3 - Heun3 - Ralston3 - SSPRK3 - RK4 - + Ralston4 - RK4Rule38 + +- + +set default int/float/complex types + +: - `brainpy.math.set_dfloat()` +- `brainpy.math.set_dint()` +- `brainpy.math.set_dcomplex()` + +- + +Delay variables + +: - `brainpy.math.FixedLenDelay` +- `brainpy.math.NeutralDelay` + +- + +Dedicated operators + +: - `brainpy.math.sparse_matmul()` + +- More numpy-like operators + +- Neural network building `brainpy.nn` + +- Dynamics model building and simulation `brainpy.dyn` + +### Version 2.0.2 (2022.02.11) + +There are important updates by [Chaoming +Wang](https://github.com/chaoming0625) in BrainPy 2.0.2. + +- provide `pre2post_event_prod` operator +- support array creation from a list/tuple of JaxArray in + `brainpy.math.asarray` and `brainpy.math.array` +- update `brainpy.ConstantDelay`, add `.latest` and `.oldest` + attributes +- add `brainpy.IntegratorRunner` support for efficient simulation of + brainpy integrators +- support auto finding of RandomState when JIT SDE integrators +- fix bugs in SDE `exponential_euler` method +- move `parallel` running APIs into `brainpy.simulation` +- add `brainpy.math.syn2post_mean`, `brainpy.math.syn2post_softmax`, + `brainpy.math.pre2post_mean` and `brainpy.math.pre2post_softmax` + operators + +### Version 2.0.1 (2022.01.31) + +Today we release BrainPy 2.0.1. This release is composed of over 70 +commits since 2.0.0, made by [Chaoming +Wang](https://github.com/chaoming0625), [Xiaoyu +Chen](mailto:c-xy17@tsinghua.org.cn), and [Tianqiu +Zhang](mailto:tianqiuakita@gmail.com) . + +BrainPy 2.0.0 updates are focused on improving documentation and +operators. Core changes include: + +- Improve `brainpylib` operators +- Complete documentation for programming system +- Add more numpy APIs +- Add `jaxfwd` in autograd module +- And other changes + +### Version 2.0.0.1 (2022.01.05) + +- Add progress bar in `brainpy.StructRunner` + +### Version 2.0.0 (2021.12.31) + +Start a new version of BrainPy. + +#### Highlight + +We are excited to announce the release of BrainPy 2.0.0. This release is +composed of over 260 commits since 1.1.7, made by [Chaoming +Wang](https://github.com/chaoming0625), [Xiaoyu +Chen](mailto:c-xy17@tsinghua.org.cn), and [Tianqiu +Zhang](mailto:tianqiuakita@gmail.com) . + +BrainPy 2.0.0 updates are focused on improving performance, usability +and consistence of BrainPy. All the computations are migrated into JAX. +Model `building`, `simulation`, `training` and `analysis` are all based +on JAX. Highlights of version 2.0.0 include: + +- [brainpylib](https://pypi.org/project/brainpylib/) are provided to + dedicated operators for brain dynamics programming +- Connection APIs in `brainpy.conn` module are more efficient. +- Update analysis tools for low-dimensional and high-dimensional + systems in `brainpy.analysis` module. +- Support more general Exponential Euler methods based on automatic + differentiation. +- Improve the usability and consistence of `brainpy.math` module. +- Remove JIT compilation based on Numba. +- Separate brain building with brain simulation. + +#### Incompatible changes + +- remove `brainpy.math.use_backend()` +- remove `brainpy.math.numpy` module +- no longer support `.run()` in `brainpy.DynamicalSystem` (see New + Features) +- remove `brainpy.analysis.PhasePlane` (see New Features) +- remove `brainpy.analysis.Bifurcation` (see New Features) +- remove `brainpy.analysis.FastSlowBifurcation` (see New Features) + +#### New Features + +- + +Exponential Euler method based on automatic differentiation + +: - `brainpy.ode.ExpEulerAuto` + +- + +Numerical optimization based low-dimensional analyzers: + +: - `brainpy.analysis.PhasePlane1D` +- `brainpy.analysis.PhasePlane2D` +- `brainpy.analysis.Bifurcation1D` +- `brainpy.analysis.Bifurcation2D` +- `brainpy.analysis.FastSlow1D` +- `brainpy.analysis.FastSlow2D` + +- + +Numerical optimization based high-dimensional analyzer: + +: - `brainpy.analysis.SlowPointFinder` + +- + +Dedicated operators in `brainpy.math` module: + +: - `brainpy.math.pre2post_event_sum` +- `brainpy.math.pre2post_sum` +- `brainpy.math.pre2post_prod` +- `brainpy.math.pre2post_max` +- `brainpy.math.pre2post_min` +- `brainpy.math.pre2syn` +- `brainpy.math.syn2post` +- `brainpy.math.syn2post_prod` +- `brainpy.math.syn2post_max` +- `brainpy.math.syn2post_min` + +- + +Conversion APIs in `brainpy.math` module: + +: - `brainpy.math.as_device_array()` +- `brainpy.math.as_variable()` +- `brainpy.math.as_jaxarray()` + +- + +New autograd APIs in `brainpy.math` module: + +: - `brainpy.math.vector_grad()` + +- + +Simulation runners: + +: - `brainpy.ReportRunner` +- `brainpy.StructRunner` +- `brainpy.NumpyRunner` + +- + +Commonly used models in `brainpy.models` module + +: - `brainpy.models.LIF` +- `brainpy.models.Izhikevich` +- `brainpy.models.AdExIF` +- `brainpy.models.SpikeTimeInput` +- `brainpy.models.PoissonInput` +- `brainpy.models.DeltaSynapse` +- `brainpy.models.ExpCUBA` +- `brainpy.models.ExpCOBA` +- `brainpy.models.AMPA` +- `brainpy.models.GABAa` + +- Naming cache clean: `brainpy.clear_name_cache` + +- add safe in-place operations of `update()` method and `.value` + assignment for JaxArray + +#### Documentation + +- Complete tutorials for quickstart +- Complete tutorials for dynamics building +- Complete tutorials for dynamics simulation +- Complete tutorials for dynamics training +- Complete tutorials for dynamics analysis +- Complete tutorials for API documentation + +## brainpy 1.1.x + +If you are using `brainpy==1.x`, you can find *documentation*, +*examples*, and *models* through the following links: + +- **Documentation:** +- **Examples from papers**: + +- **Canonical brain models**: + + +### Version 1.1.7 (2021.12.13) + +- fix bugs on `numpy_array()` conversion in + [brainpy.math.utils]{.title-ref} module + +### Version 1.1.5 (2021.11.17) + +**API changes:** + +- fix bugs on ndarray import in [brainpy.base.function.py]{.title-ref} +- convenient \'get_param\' interface + [brainpy.simulation.layers]{.title-ref} +- add more weight initialization methods + +**Doc changes:** + +- add more examples in README + +### Version 1.1.4 + +**API changes:** + +- add `.struct_run()` in DynamicalSystem +- add `numpy_array()` conversion in [brainpy.math.utils]{.title-ref} + module +- add `Adagrad`, `Adadelta`, `RMSProp` optimizers +- remove [setting]{.title-ref} methods in + [brainpy.math.jax]{.title-ref} module +- remove import jax in [brainpy.\_\_init\_\_.py]{.title-ref} and + enable jax setting, including + - `enable_x64()` + - `set_platform()` + - `set_host_device_count()` +- enable `b=None` as no bias in + [brainpy.simulation.layers]{.title-ref} +- set [int\_]{.title-ref} and [float\_]{.title-ref} as default 32 bits +- remove `dtype` setting in Initializer constructor + +**Doc changes:** + +- add `optimizer` in \"Math Foundation\" +- add `dynamics training` docs +- improve others + +### Version 1.1.3 + +- fix bugs of JAX parallel API imports +- fix bugs of [post_slice]{.title-ref} structure construction +- update docs + +### Version 1.1.2 + +- add `pre2syn` and `syn2post` operators +- add [verbose]{.title-ref} and [check]{.title-ref} option to + `Base.load_states()` +- fix bugs on JIT DynamicalSystem (numpy backend) + +### Version 1.1.1 + +- fix bugs on symbolic analysis: model trajectory +- change [absolute]{.title-ref} access in the variable saving and + loading to the [relative]{.title-ref} access +- add UnexpectedTracerError hints in JAX transformation functions + +### Version 1.1.0 (2021.11.08) + +This package releases a new version of BrainPy. + +Highlights of core changes: + +#### `math` module + +- support numpy backend +- support JAX backend +- support `jit`, `vmap` and `pmap` on class objects on JAX backend +- support `grad`, `jacobian`, `hessian` on class objects on JAX + backend +- support `make_loop`, `make_while`, and `make_cond` on JAX backend +- support `jit` (based on numba) on class objects on numpy backend +- unified numpy-like ndarray operation APIs +- numpy-like random sampling APIs +- FFT functions +- gradient descent optimizers +- activation functions +- loss function +- backend settings + +#### `base` module + +- `Base` for whole Version ecosystem +- `Function` to wrap functions +- `Collector` and `TensorCollector` to collect variables, integrators, + nodes and others + +#### `integrators` module + +- class integrators for ODE numerical methods +- class integrators for SDE numerical methods + +#### `simulation` module + +- support modular and composable programming +- support multi-scale modeling +- support large-scale modeling +- support simulation on GPUs +- fix bugs on `firing_rate()` +- remove `_i` in `update()` function, replace `_i` with `_dt`, meaning + the dynamic system has the canonic equation form of + $dx/dt = f(x, t, dt)$ +- reimplement the `input_step` and `monitor_step` in a more intuitive + way +- support to set [dt]{.title-ref} in the single object level (i.e., + single instance of DynamicSystem) +- common used DNN layers +- weight initializations +- refine synaptic connections + +## brainpy 1.0.x + +### Version 1.0.3 (2021.08.18) + +Fix bugs on + +- firing rate measurement +- stability analysis + +### Version 1.0.2 + +This release continues to improve the user-friendliness. + +Highlights of core changes: + +- Remove support for Numba-CUDA backend +- Super initialization [super(XXX, self).\_\_init\_\_()]{.title-ref} + can be done at anywhere (not required to add at the bottom of the + [\_\_init\_\_()]{.title-ref} function). +- Add the output message of the step function running error. +- More powerful support for Monitoring +- More powerful support for running order scheduling +- Remove [unsqueeze()]{.title-ref} and [squeeze()]{.title-ref} + operations in `brainpy.ops` +- Add [reshape()]{.title-ref} operation in `brainpy.ops` +- Improve docs for numerical solvers +- Improve tests for numerical solvers +- Add keywords checking in ODE numerical solvers +- Add more unified operations in brainpy.ops +- Support \"@every\" in steps and monitor functions +- Fix ODE solver bugs for class bounded function +- Add build phase in Monitor + +### Version 1.0.1 + +- Fix bugs + +### Version 1.0.0 + +- **NEW VERSION OF BRAINPY** +- Change the coding style into the object-oriented programming +- Systematically improve the documentation + +## brainpy 0.x + +### Version 0.3.5 + +- Add \'timeout\' in sympy solver in neuron dynamics analysis +- Reconstruct and generalize phase plane analysis +- Generalize the repeat mode of `Network` to different running + duration between two runs +- Update benchmarks +- Update detailed documentation + +### Version 0.3.1 + +- Add a more flexible way for NeuState/SynState initialization +- Fix bugs of \"is_multi_return\" +- Add \"hand_overs\", \"requires\" and \"satisfies\". +- Update documentation +- Auto-transform [range]{.title-ref} to [numba.prange]{.title-ref} +- Support [\_obj_i]{.title-ref}, [\_pre_i]{.title-ref}, + [\_post_i]{.title-ref} for more flexible operation in scalar-based + models + +### Version 0.3.0 + +#### Computation API + +- Rename \"brainpy.numpy\" to \"brainpy.backend\" +- Delete \"pytorch\", \"tensorflow\" backends +- Add \"numba\" requirement +- Add GPU support + +#### Profile setting + +- Delete \"backend\" profile setting, add \"jit\" + +#### Core systems + +- Delete \"autopepe8\" requirement +- Delete the format code prefix +- Change keywords \"\_[t](), \_[dt](), \_[i]()\" to \"\_t, \_dt, \_i\" +- Change the \"ST\" declaration out of \"requires\" +- Add \"repeat\" mode run in Network +- Change \"vector-based\" to \"mode\" in NeuType and SynType + definition + +#### Package installation + +- Remove \"pypi\" installation, installation now only rely on + \"conda\" + +### Version 0.2.4 + +#### API changes + +- Fix bugs + +### Version 0.2.3 + +#### API changes + +- Add \"animate_1D\" in `visualization` module +- Add \"PoissonInput\", \"SpikeTimeInput\" and \"FreqInput\" in + `inputs` module +- Update phase_portrait_analyzer.py + +#### Models and examples + +- Add CANN examples + +### Version 0.2.2 + +#### API changes + +- Redesign visualization +- Redesign connectivity +- Update docs + +### Version 0.2.1 + +#### API changes + +- Fix bugs in [numba import]{.title-ref} +- Fix bugs in [numpy]{.title-ref} mode with [scalar]{.title-ref} model + +### Version 0.2.0 + +#### API changes + +- For computation: `numpy`, `numba` +- For model definition: `NeuType`, `SynConn` +- For model running: `Network`, `NeuGroup`, `SynConn`, `Runner` +- For numerical integration: `integrate`, `Integrator`, `DiffEquation` +- For connectivity: `One2One`, `All2All`, `GridFour`, `grid_four`, + `GridEight`, `grid_eight`, `GridN`, `FixedPostNum`, `FixedPreNum`, + `FixedProb`, `GaussianProb`, `GaussianWeight`, `DOG` +- For visualization: `plot_value`, `plot_potential`, `plot_raster`, + `animation_potential` +- For measurement: `cross_correlation`, `voltage_fluctuation`, + `raster_plot`, `firing_rate` +- For inputs: `constant_current`, `spike_current`, `ramp_current`. + +#### Models and examples + +- Neuron models: `HH model`, `LIF model`, `Izhikevich model` +- Synapse models: `AMPA`, `GABA`, `NMDA`, `STP`, `GapJunction` +- Network models: `gamma oscillation` diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py index f9145a94..26a285cf 100644 --- a/brainpy/_src/dyn/neurons/hh.py +++ b/brainpy/_src/dyn/neurons/hh.py @@ -315,6 +315,9 @@ def __init__( m_initializer: Optional[Union[Callable, ArrayType]] = None, h_initializer: Optional[Union[Callable, ArrayType]] = None, n_initializer: Optional[Union[Callable, ArrayType]] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -340,8 +343,14 @@ def __init__( self._n_initializer = is_initializer(n_initializer, allow_none=True) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=4) + # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) # model if init_var: @@ -622,6 +631,9 @@ def __init__( V_th: Union[float, ArrayType, Callable] = 10., W_initializer: Union[Callable, ArrayType] = OneInit(0.02), V_initializer: Union[Callable, ArrayType] = Uniform(-70., -60.), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -650,8 +662,13 @@ def __init__( self._W_initializer = is_initializer(W_initializer) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=2) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # model if init_var: @@ -895,6 +912,9 @@ def __init__( V_initializer: Union[Callable, ArrayType] = OneInit(-65.), h_initializer: Union[Callable, ArrayType] = OneInit(0.6), n_initializer: Union[Callable, ArrayType] = OneInit(0.32), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -920,8 +940,13 @@ def __init__( self._n_initializer = is_initializer(n_initializer) self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape, num_vars=3) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # model if init_var: diff --git a/brainpy/_src/dyn/neurons/lif.py b/brainpy/_src/dyn/neurons/lif.py index 11934d9d..30b8b29c 100644 --- a/brainpy/_src/dyn/neurons/lif.py +++ b/brainpy/_src/dyn/neurons/lif.py @@ -7,8 +7,8 @@ from brainpy._src.context import share from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc from brainpy._src.dyn.neurons.base import GradNeuDyn -from brainpy._src.initialize import ZeroInit, OneInit -from brainpy._src.integrators import odeint, JointEq +from brainpy._src.initialize import ZeroInit, OneInit, noise as init_noise +from brainpy._src.integrators import odeint, sdeint, JointEq from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType, Sharding @@ -220,6 +220,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, ): # initialization super().__init__(size=size, @@ -244,8 +247,14 @@ def __init__( # initializers self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape) + # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -418,6 +427,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Optional[Union[float, ArrayType, Callable]] = None, ): # initialization super().__init__( @@ -441,6 +453,8 @@ def __init__( R=R, tau=tau, V_initializer=V_initializer, + + noise=noise, ) # parameters @@ -689,6 +703,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -715,8 +732,13 @@ def __init__( # initializers self._V_initializer = is_initializer(V_initializer) + # noise + self.noise = init_noise(noise, self.varshape) # integral - self.integral = odeint(method=method, f=self.derivative) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -1023,6 +1045,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -1048,6 +1073,7 @@ def __init__( R=R, tau=tau, V_initializer=V_initializer, + noise=noise, ) # parameters @@ -1365,6 +1391,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., V_initializer: Union[Callable, ArrayType] = ZeroInit(), w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -1395,7 +1424,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -1700,6 +1733,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -1740,7 +1776,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2011,6 +2051,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., tau: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -2037,7 +2080,11 @@ def __init__( self._V_initializer = is_initializer(V_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2280,6 +2327,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -2315,7 +2365,11 @@ def __init__( self._V_initializer = is_initializer(V_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=1) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2576,6 +2630,9 @@ def __init__( tau_w: Union[float, ArrayType, Callable] = 10., V_initializer: Union[Callable, ArrayType] = ZeroInit(), w_initializer: Union[Callable, ArrayType] = ZeroInit(), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -2605,7 +2662,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -2884,6 +2945,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -2923,7 +2987,11 @@ def __init__( self._w_initializer = is_initializer(w_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3232,6 +3300,9 @@ def __init__( I1_initializer: Union[Callable, ArrayType] = ZeroInit(), I2_initializer: Union[Callable, ArrayType] = ZeroInit(), Vth_initializer: Union[Callable, ArrayType] = OneInit(-50.), + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -3268,7 +3339,11 @@ def __init__( self._Vth_initializer = is_initializer(Vth_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3617,6 +3692,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -3665,7 +3743,11 @@ def __init__( self._Vth_initializer = is_initializer(Vth_initializer) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=4) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -3977,6 +4059,9 @@ def __init__( R: Union[float, ArrayType, Callable] = 1., V_initializer: Union[Callable, ArrayType] = OneInit(-70.), u_initializer: Union[Callable, ArrayType] = None, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__(size=size, @@ -4010,7 +4095,11 @@ def __init__( self._u_initializer = is_initializer(u_initializer, allow_none=True) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: @@ -4297,6 +4386,9 @@ def __init__( # new neuron parameter tau_ref: Union[float, ArrayType, Callable] = 0., ref_var: bool = False, + + # noise + noise: Union[float, ArrayType, Callable] = None, ): # initialization super().__init__( @@ -4337,7 +4429,11 @@ def __init__( self._u_initializer = is_initializer(u_initializer, allow_none=True) # integral - self.integral = odeint(method=method, f=self.derivative) + self.noise = init_noise(noise, self.varshape, num_vars=2) + if self.noise is not None: + self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + else: + self.integral = odeint(method=method, f=self.derivative) # variables if init_var: diff --git a/brainpy/_src/dynold/neurons/biological_models.py b/brainpy/_src/dynold/neurons/biological_models.py index 43b2c2a5..8daa7acd 100644 --- a/brainpy/_src/dynold/neurons/biological_models.py +++ b/brainpy/_src/dynold/neurons/biological_models.py @@ -196,15 +196,11 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, **kwargs, ): self.input_var = input_var super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=4) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -302,14 +298,10 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, **kwargs, ): self.input_var = input_var super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -808,14 +800,11 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, + **kwargs, ): self.input_var = input_var super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=3) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): diff --git a/brainpy/_src/dynold/neurons/reduced_models.py b/brainpy/_src/dynold/neurons/reduced_models.py index 9615e1a5..e0eb6b56 100644 --- a/brainpy/_src/dynold/neurons/reduced_models.py +++ b/brainpy/_src/dynold/neurons/reduced_models.py @@ -199,7 +199,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None, spike_fun: Callable = None, **kwargs, ): @@ -207,9 +206,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -338,9 +335,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -441,7 +436,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Optional[Union[float, ArrayType, Initializer, Callable]] = None, spike_fun: Callable = None, **kwargs, ): @@ -449,9 +443,7 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) + self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -541,7 +533,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -549,9 +540,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=1) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -651,7 +639,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -659,9 +646,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -769,7 +753,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -777,9 +760,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=4) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None): @@ -873,7 +853,6 @@ def __init__( self, *args, input_var: bool = True, - noise: Union[float, ArrayType, Initializer, Callable] = None, spike_fun: Callable = None, **kwargs, ): @@ -881,9 +860,6 @@ def __init__( if spike_fun is not None: kwargs['spk_fun'] = spike_fun super().__init__(*args, **kwargs, init_var=False) - self.noise = init_noise(noise, self.varshape, num_vars=2) - if self.noise is not None: - self.integral = sdeint(method=self.method, f=self.derivative, g=self.noise) self.reset_state(self.mode) def reset_state(self, batch_size=None, **kwargs): diff --git a/brainpylib-changelog.md b/brainpylib-changelog.md new file mode 100644 index 00000000..888a9c68 --- /dev/null +++ b/brainpylib-changelog.md @@ -0,0 +1,62 @@ +# Release notes (``brainpylib``) + +## Version 0.3.0 + +- Fix bugs on windows platform +- remove all customized C++ and CUDA operators + + +## Version 0.2.8 + +- Fix bugs that the DLL cannot be loaded correctly when windows does not have a c++ environment, + +## ~~Version 0.2.7(YANKED)~~ + +## Version 0.2.6 + +- Fix bugs of taichi call function for single result + +## Version 0.2.5 + +- Add new taichi call function for single result on CPU backend + +## Version 0.2.4 + +- Add taichi customized operator call on arm64 backend + +## ~~Version 0.2.3(YANKED)~~ + +## Version 0.2.2 + +- Fix bugs of just-in-time connectivity operators on CPU device + +## Version 0.2.1 + +- Fix bugs of Taichi AOT call on GPU backend by ``cudaMemset()`` CUDA arrays + +## Version 0.2.0 + +- Add XLA custom call from [Taichi](https://github.com/taichi-dev/taichi) AOT (ahead of time) operators on both CPU and + GPU platforms + +## Version 0.0.5 + +- Support operator customization on GPU by ``numba`` + +## Version 0.0.4 + +- Support operator customization on CPU by ``numba`` + +## Version 0.0.3 + +- Support ``event_sum()`` operator on GPU +- Support ``event_prod()`` operator on CPU +- Support ``atomic_sum()`` operator on GPU +- Support ``atomic_prod()`` operator on CPU and GPU + +## Version 0.0.2 + +- Support ``event_sum()`` operator on CPU +- Support ``event_sum2()`` operator on CPU +- Support ``atomic_sum()`` operator on CPU + diff --git a/changelog.rst b/changelog.rst deleted file mode 100644 index c54357f8..00000000 --- a/changelog.rst +++ /dev/null @@ -1,1083 +0,0 @@ -Release notes (brainpy) -####################### - - - - -.. note:: - - All history release notes please see `GitHub releases `_. - - - - -brainpy 2.2.x -************* - -BrainPy 2.2.x is a complete re-design of the framework, -tackling the shortcomings of brainpy 2.1.x generation, -effectively bringing it to research needs and standards. - - - -Version 2.2.1 (2022.09.09) -========================== - -This release fixes bugs found in the codebase and improves the usability and functions of BrainPy. - -Bug fixes -~~~~~~~~~~~~~~ - - -#. Fix the bug of operator customization in ``brainpy.math.XLACustomOp`` and ``brainpy.math.register_op``. Now, it supports operator customization by using NumPy and Numba interface. For instance, - -.. code-block:: python - - import brainpy.math as bm - - def abs_eval(events, indices, indptr, post_val, values): - return post_val - - def con_compute(outs, ins): - post_val = outs - events, indices, indptr, _, values = ins - for i in range(events.size): - if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - old_value = post_val[index] - post_val[index] = values + old_value - - event_sum = bm.XLACustomOp(eval_shape=abs_eval, con_compute=con_compute) - - -#. Fix the bug of ``brainpy.tools.DotDict``. Now, it is compatible with the transformations of JAX. For instance, - -.. code-block:: python - - import brainpy as bp - from jax import vmap - - @vmap - def multiple_run(I): - hh = bp.neurons.HH(1) - runner = bp.dyn.DSRunner(hh, inputs=('input', I), numpy_mon_after_run=False) - runner.run(100.) - return runner.mon - - mon = multiple_run(bp.math.arange(2, 10, 2)) - -New features -~~~~~~~~~~~~~~ - - -#. Add numpy operators ``brainpy.math.mat``\ , ``brainpy.math.matrix``\ , ``brainpy.math.asmatrix``. -#. Improve translation rules of brainpylib operators, improve its running speeds. -#. Support ``DSView`` of ``DynamicalSystem`` instance. Now, it supports defining models with a slice view of a DS instance. For example, - -.. code-block:: python - - import brainpy as bp - import brainpy.math as bm - - - class EINet_V2(bp.dyn.Network): - def __init__(self, scale=1.0, method='exp_auto'): - super(EINet_V2, self).__init__() - - # network size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - self.N = bp.neurons.LIF(num_exc + num_inh, - V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - method=method, V_initializer=bp.initialize.Normal(-55., 2.)) - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.Esyn = bp.synapses.Exponential(pre=self.N[:num_exc], post=self.N, - conn=bp.connect.FixedProb(0.02), - g_max=we, tau=5., - output=bp.synouts.COBA(E=0.), - method=method) - self.Isyn = bp.synapses.Exponential(pre=self.N[num_exc:], post=self.N, - conn=bp.connect.FixedProb(0.02), - g_max=wi, tau=10., - output=bp.synouts.COBA(E=-80.), - method=method) - - net = EINet_V2(scale=1., method='exp_auto') - # simulation - runner = bp.dyn.DSRunner( - net, - monitors={'spikes': net.N.spike}, - inputs=[(net.N.input, 20.)] - ) - runner.run(100.) - - # visualization - bp.visualize.raster_plot(runner.mon.ts, runner.mon['spikes'], show=True) - - - - -Version 2.2.0 (2022.08.12) -========================== - - - -This release has provided important improvements for BrainPy, including usability, speed, functions, and others. - -Backwards Incompatible changes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - -1. ``brainpy.nn`` module is no longer supported and has been removed since version 2.2.0. Instead, users should use ``brainpy.train`` module for the training of BP algorithms, online learning, or offline learning algorithms, and ``brainpy.algorithms`` module for online / offline training algorithms. -2. The ``update()`` function for the model definition has been changed: - -.. code-block:: - - >>> # 2.1.x - >>> - >>> import brainpy as bp - >>> - >>> class SomeModel(bp.dyn.DynamicalSystem): - >>> def __init__(self, ): - >>> ...... - >>> def update(self, t, dt): - >>> pass - >>> # 2.2.x - >>> - >>> import brainpy as bp - >>> - >>> class SomeModel(bp.dyn.DynamicalSystem): - >>> def __init__(self, ): - >>> ...... - >>> def update(self, tdi): - >>> t, dt = tdi.t, tdi.dt - >>> pass - -where ``tdi`` can be defined with other names, like ``sha``\ , to represent the shared argument across modules. - -Deprecations -~~~~~~~~~~~~~~~~~~~~ - - -#. ``brainpy.dyn.xxx (neurons)`` and ``brainpy.dyn.xxx (synapse)`` are no longer supported. Please use ``brainpy.neurons``\ , ``brainpy.synapses`` modules. -#. ``brainpy.running.monitor`` has been removed. -#. ``brainpy.nn`` module has been removed. - -New features -~~~~~~~~~~~~~~~~~~~~ - - -1. ``brainpy.math.Variable`` receives a ``batch_axis`` setting to represent the batch axis of the data. - -.. code-block:: - - >>> import brainpy.math as bm - >>> a = bm.Variable(bm.zeros((1, 4, 5)), batch_axis=0) - >>> a.value = bm.zeros((2, 4, 5)) # success - >>> a.value = bm.zeros((1, 2, 5)) # failed - MathError: The shape of the original data is (2, 4, 5), while we got (1, 2, 5) with batch_axis=0. - - -2. ``brainpy.train`` provides ``brainpy.train.BPTT`` for back-propagation algorithms, ``brainpy.train.Onlinetrainer`` for online training algorithms, ``brainpy.train.OfflineTrainer`` for offline training algorithms. -3. ``brainpy.Base`` class supports ``_excluded_vars`` setting to ignore variables when retrieving variables by using ``Base.vars()`` method. - -.. code-block:: - - >>> class OurModel(bp.Base): - >>> _excluded_vars = ('a', 'b') - >>> def __init__(self): - >>> super(OurModel, self).__init__() - >>> self.a = bm.Variable(bm.zeros(10)) - >>> self.b = bm.Variable(bm.ones(20)) - >>> self.c = bm.Variable(bm.random.random(10)) - >>> - >>> model = OurModel() - >>> model.vars().keys() - dict_keys(['OurModel0.c']) - - -4. ``brainpy.analysis.SlowPointFinder`` supports directly analyzing an instance of ``brainpy.dyn.DynamicalSystem``. - -.. code-block:: - - >>> hh = bp.neurons.HH(1) - >>> finder = bp.analysis.SlowPointFinder(hh, target_vars={'V': hh.V, 'm': hh.m, 'h': hh.h, 'n': hh.n}) - - -5. ``brainpy.datasets`` supports MNIST, FashionMNIST, and other datasets. -6. Supports defining conductance-based neuron models``. - -.. code-block:: - - >>> class HH(bp.dyn.CondNeuGroup): - >>> def __init__(self, size): - >>> super(HH, self).__init__(size) - >>> - >>> self.INa = channels.INa_HH1952(size, ) - >>> self.IK = channels.IK_HH1952(size, ) - >>> self.IL = channels.IL(size, E=-54.387, g_max=0.03) - - -7. ``brainpy.layers`` module provides commonly used models for DNN and reservoir computing. -8. Support composable definition of synaptic models by using ``TwoEndConn``\ , ``SynOut``\ , ``SynSTP`` and ``SynLTP``. - -.. code-block:: - - >>> bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(prob), - >>> g_max=0.03 / scale, tau=5, - >>> output=bp.synouts.COBA(E=0.), - >>> stp=bp.synplast.STD()) - - -9. Provide commonly used surrogate gradient function for spiking generation, including - - * ``brainpy.math.spike_with_sigmoid_grad`` - * ``brainpy.math.spike_with_linear_grad`` - * ``brainpy.math.spike_with_gaussian_grad`` - * ``brainpy.math.spike_with_mg_grad`` - -10. Provide shortcuts for GPU memory management via ``brainpy.math.disable_gpu_memory_preallocation()``\ , and ``brainpy.math.clear_buffer_memory()``. - -What's Changed -~~~~~~~~~~~~~~~~~~~~ - - -* fix `#207 `_\ : synapses update first, then neurons, finally delay variables by `@chaoming0625 `_ in `#219 `_ -* docs: add logos by `@ztqakita `_ in `#218 `_ -* Add the biological NMDA model by `@c-xy17 `_ in `#221 `_ -* docs: fix mathjax problem by `@ztqakita `_ in `#222 `_ -* Add the parameter R to the LIF model by `@c-xy17 `_ in `#224 `_ -* new version of brainpy: V2.2.0-rc1 by `@chaoming0625 `_ in `#226 `_ -* update training apis by `@chaoming0625 `_ in `#227 `_ -* Update quickstart and the analysis module by `@c-xy17 `_ in `#229 `_ -* Eseential updates for montors, analysis, losses, and examples by `@chaoming0625 `_ in `#230 `_ -* add numpy op tests by `@ztqakita `_ in `#231 `_ -* Integrated simulation, simulaton and analysis by `@chaoming0625 `_ in `#232 `_ -* update docs by `@chaoming0625 `_ in `#233 `_ -* unify ``brainpy.layers`` with other modules in ``brainpy.dyn`` by `@chaoming0625 `_ in `#234 `_ -* fix bugs by `@chaoming0625 `_ in `#235 `_ -* update apis, docs, examples and others by `@chaoming0625 `_ in `#236 `_ -* fixes by `@chaoming0625 `_ in `#237 `_ -* fix: add dtype promotion = standard by `@ztqakita `_ in `#239 `_ -* updates by `@chaoming0625 `_ in `#240 `_ -* update training docs by `@chaoming0625 `_ in `#241 `_ -* change doc path/organization by `@chaoming0625 `_ in `#242 `_ -* Update advanced docs by `@chaoming0625 `_ in `#243 `_ -* update quickstart docs & enable jit error checking by `@chaoming0625 `_ in `#244 `_ -* update apis and examples by `@chaoming0625 `_ in `#245 `_ -* update apis and tests by `@chaoming0625 `_ in `#246 `_ -* Docs update and bugs fixed by `@ztqakita `_ in `#247 `_ -* version 2.2.0 by `@chaoming0625 `_ in `#248 `_ -* add norm and pooling & fix bugs in operators by `@ztqakita `_ in `#249 `_ - -**Full Changelog**: `V2.1.12...V2.2.0 `_ - - - - -brainpy 2.1.x -************* - - - -Version 2.1.12 (2022.05.17) -=========================== - - -Highlights -~~~~~~~~~~ - -This release is excellent. We have made important improvements. - -1. We provide dozens of random sampling in NumPy which are not - supportted in JAX, such as ``brainpy.math.random.bernoulli``, - ``brainpy.math.random.lognormal``, ``brainpy.math.random.binomial``, - ``brainpy.math.random.chisquare``, ``brainpy.math.random.dirichlet``, - ``brainpy.math.random.geometric``, ``brainpy.math.random.f``, - ``brainpy.math.random.hypergeometric``, - ``brainpy.math.random.logseries``, - ``brainpy.math.random.multinomial``, - ``brainpy.math.random.multivariate_normal``, - ``brainpy.math.random.negative_binomial``, - ``brainpy.math.random.noncentral_chisquare``, - ``brainpy.math.random.noncentral_f``, ``brainpy.math.random.power``, - ``brainpy.math.random.rayleigh``, ``brainpy.math.random.triangular``, - ``brainpy.math.random.vonmises``, ``brainpy.math.random.wald``, - ``brainpy.math.random.weibull`` -2. make efficient checking on numerical values. Instead of direct - ``id_tap()`` checking which has large overhead, currently - ``brainpy.tools.check_erro_in_jit()`` is highly efficient. -3. Fix ``JaxArray`` operator errors on ``None`` -4. improve oo-to-function transformation speeds -5. ``io`` works: ``.save_states()`` and ``.load_states()`` - -What’s Changed -~~~~~~~~~~~~~~ - -- support dtype setting in array interchange functions by - [@chaoming0625](https://github.com/chaoming0625) in - `#209 `__ -- fix `#144 `__: - operations on None raise errors by - [@chaoming0625](https://github.com/chaoming0625) in - `#210 `__ -- add tests and new functions for random sampling by - [@c-xy17](https://github.com/c-xy17) in - `#213 `__ -- feat: fix ``io`` for brainpy.Base by - [@chaoming0625](https://github.com/chaoming0625) in - `#211 `__ -- update advanced tutorial documentation by - [@chaoming0625](https://github.com/chaoming0625) in - `#212 `__ -- fix `#149 `__ - (dozens of random samplings in NumPy) and fix JaxArray op errors by - [@chaoming0625](https://github.com/chaoming0625) in - `#216 `__ -- feat: efficient checking on numerical values by - [@chaoming0625](https://github.com/chaoming0625) in - `#217 `__ - -**Full Changelog**: -`V2.1.11...V2.1.12 `__ - - - -Version 2.1.11 (2022.05.15) -=========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* fix: cross-correlation bug by `@ztqakita `_ in `#201 `_ -* update apis, test and docs of numpy ops by `@chaoming0625 `_ in `#202 `_ -* docs: add sphinx_book_theme by `@ztqakita `_ in `#203 `_ -* fix: add requirements-doc.txt by `@ztqakita `_ in `#204 `_ -* update control flow, integrators, operators, and docs by `@chaoming0625 `_ in `#205 `_ -* improve oo-to-function transformation speed by `@chaoming0625 `_ in `#208 `_ - -**Full Changelog**\ : `V2.1.10...V2.1.11 `_ - - - -Version 2.1.10 (2022.05.05) -=========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* update control flow APIs and Docs by `@chaoming0625 `_ in `#192 `_ -* doc: update docs of dynamics simulation by `@chaoming0625 `_ in `#193 `_ -* fix `#125 `_: add channel models and two-compartment Pinsky-Rinzel model by `@chaoming0625 `_ in `#194 `_ -* JIT errors do not change Variable values by `@chaoming0625 `_ in `#195 `_ -* fix a bug in math.activations.py by `@c-xy17 `_ in `#196 `_ -* Functionalinaty improvements by `@chaoming0625 `_ in `#197 `_ -* update rate docs by `@chaoming0625 `_ in `#198 `_ -* update brainpy.dyn doc by `@chaoming0625 `_ in `#199 `_ - -**Full Changelog**\ : `V2.1.8...V2.1.10 `_ - - - -Version 2.1.8 (2022.04.26) -========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* Fix `#120 `_ by `@chaoming0625 `_ in `#178 `_ -* feat: brainpy.Collector supports addition and subtraction by `@chaoming0625 `_ in `#179 `_ -* feat: delay variables support "indices" and "reset()" function by `@chaoming0625 `_ in `#180 `_ -* Support reset functions in neuron and synapse models by `@chaoming0625 `_ in `#181 `_ -* ``update()`` function on longer need ``_t`` and ``_dt`` by `@chaoming0625 `_ in `#183 `_ -* small updates by `@chaoming0625 `_ in `#188 `_ -* feat: easier control flows with ``brainpy.math.ifelse`` by `@chaoming0625 `_ in `#189 `_ -* feat: update delay couplings of ``DiffusiveCoupling`` and ``AdditiveCouping`` by `@chaoming0625 `_ in `#190 `_ -* update version and changelog by `@chaoming0625 `_ in `#191 `_ - -**Full Changelog**\ : `V2.1.7...V2.1.8 `_ - - - -Version 2.1.7 (2022.04.22) -========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* synapse models support heterogeneuos weights by `@chaoming0625 `_ in `#170 `_ -* more efficient synapse implementation by `@chaoming0625 `_ in `#171 `_ -* fix input models in brainpy.dyn by `@chaoming0625 `_ in `#172 `_ -* fix: np array astype by `@ztqakita `_ in `#173 `_ -* update README: 'brain-py' to 'brainpy' by `@chaoming0625 `_ in `#174 `_ -* fix: fix the updating rules in the STP model by `@c-xy17 `_ in `#176 `_ -* Updates and fixes by `@chaoming0625 `_ in `#177 `_ - -**Full Changelog**\ : `V2.1.5...V2.1.7 `_ - - -Version 2.1.5 (2022.04.18) -========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* ``brainpy.math.random.shuffle`` is numpy like by `@chaoming0625 `_ in `#153 `_ -* update LICENSE by `@chaoming0625 `_ in `#155 `_ -* docs: add m1 warning by `@ztqakita `_ in `#154 `_ -* compatible apis of 'brainpy.math' with those of 'jax.numpy' in most modules by `@chaoming0625 `_ in `#156 `_ -* Important updates by `@chaoming0625 `_ in `#157 `_ -* Updates by `@chaoming0625 `_ in `#159 `_ -* Add LayerNorm, GroupNorm, and InstanceNorm as nn_nodes in normalization.py by `@c-xy17 `_ in `#162 `_ -* feat: add conv & pooling nodes by `@ztqakita `_ in `#161 `_ -* fix: update setup.py by `@ztqakita `_ in `#163 `_ -* update setup.py by `@chaoming0625 `_ in `#165 `_ -* fix: change trigger condition by `@ztqakita `_ in `#166 `_ -* fix: add build_conn() function by `@ztqakita `_ in `#164 `_ -* update synapses by `@chaoming0625 `_ in `#167 `_ -* get the deserved name: brainpy by `@chaoming0625 `_ in `#168 `_ -* update tests by `@chaoming0625 `_ in `#169 `_ - -**Full Changelog**\ : `V2.1.4...V2.1.5 `_ - - - -Version 2.1.4 (2022.04.04) -========================== - - -What's Changed -~~~~~~~~~~~~~~ - -* fix doc parsing bug by `@chaoming0625 `_ in `#127 `_ -* Update overview_of_dynamic_model.ipynb by `@c-xy17 `_ in `#129 `_ -* Reorganization of ``brainpylib.custom_op`` and adding interface in ``brainpy.math`` by `@ztqakita `_ in `#128 `_ -* Fix: modify ``register_op`` and brainpy.math interface by `@ztqakita `_ in `#130 `_ -* new features about RNN training and delay differential equations by `@chaoming0625 `_ in `#132 `_ -* Fix `#123 `_\ : Add low-level operators docs and modify register_op by `@ztqakita `_ in `#134 `_ -* feat: add generate_changelog by `@ztqakita `_ in `#135 `_ -* fix `#133 `_\ , support batch size training with offline algorithms by `@chaoming0625 `_ in `#136 `_ -* fix `#84 `_\ : support online training algorithms by `@chaoming0625 `_ in `#137 `_ -* feat: add the batch normalization node by `@c-xy17 `_ in `#138 `_ -* fix: fix shape checking error by `@chaoming0625 `_ in `#139 `_ -* solve `#131 `_\ , support efficient synaptic computation for special connection types by `@chaoming0625 `_ in `#140 `_ -* feat: update the API and test for batch normalization by `@c-xy17 `_ in `#142 `_ -* Node is default trainable by `@chaoming0625 `_ in `#143 `_ -* Updates training apis and docs by `@chaoming0625 `_ in `#145 `_ -* fix: add dependencies and update version by `@ztqakita `_ in `#147 `_ -* update requirements by `@chaoming0625 `_ in `#146 `_ -* data pass of the Node is default SingleData by `@chaoming0625 `_ in `#148 `_ - -**Full Changelog**\ : `V2.1.3...V2.1.4 `_ - - - -Version 2.1.3 (2022.03.27) -========================== - -This release improves the functionality and usability of BrainPy. Core changes include - -* support customization of low-level operators by using Numba -* fix bugs - -What's Changed -~~~~~~~~~~~~~~ - -* Provide custom operators written in numba for jax jit by `@ztqakita `_ in `#122 `_ -* fix DOGDecay bugs, add more features by `@chaoming0625 `_ in `#124 `_ -* fix bugs by `@chaoming0625 `_ in `#126 `_ - -**Full Changelog** : `V2.1.2...V2.1.3 `_ - - - - -Version 2.1.2 (2022.03.23) -========================== - -This release improves the functionality and usability of BrainPy. Core changes include - -- support rate-based whole-brain modeling -- add more neuron models, including rate neurons/synapses -- support Python 3.10 -- improve delays etc. APIs - - -What's Changed -~~~~~~~~~~~~~~ - -* fix matplotlib dependency on "brainpy.analysis" module by `@chaoming0625 `_ in `#110 `_ -* Sync master to brainpy-2.x branch by `@ztqakita `_ in `#111 `_ -* add py3.6 test & delete multiple macos env by `@ztqakita `_ in `#112 `_ -* Modify ci by `@ztqakita `_ in `#113 `_ -* Add py3.10 test by `@ztqakita `_ in `#115 `_ -* update python version by `@chaoming0625 `_ in `#114 `_ -* add brainpylib mac py3.10 by `@ztqakita `_ in `#116 `_ -* Enhance measure/input/brainpylib by `@chaoming0625 `_ in `#117 `_ -* fix `#105 `_\ : Add customize connections docs by `@ztqakita `_ in `#118 `_ -* fix bugs by `@chaoming0625 `_ in `#119 `_ -* Whole brain modeling by `@chaoming0625 `_ in `#121 `_ - -**Full Changelog**: `V2.1.1...V2.1.2 `_ - - -Version 2.1.1 (2022.03.18) -========================== - -This release continues to update the functionality of BrainPy. Core changes include - -- numerical solvers for fractional differential equations -- more standard ``brainpy.nn`` interfaces - - -New Features -~~~~~~~~~~~~ - -- Numerical solvers for fractional differential equations - - ``brainpy.fde.CaputoEuler`` - - ``brainpy.fde.CaputoL1Schema`` - - ``brainpy.fde.GLShortMemory`` -- Fractional neuron models - - ``brainpy.dyn.FractionalFHR`` - - ``brainpy.dyn.FractionalIzhikevich`` -- support ``shared_kwargs`` in `RNNTrainer` and `RNNRunner` - - -Version 2.1.0 (2022.03.14) -========================== - - -Highlights -~~~~~~~~~~ - -We are excited to announce the release of BrainPy 2.1.0. This release is composed of nearly -270 commits since 2.0.2, made by `Chaoming Wang `_, -`Xiaoyu Chen `_, and `Tianqiu Zhang `_ . - -BrainPy 2.1.0 updates are focused on improving usability, functionality, and stability of BrainPy. -Highlights of version 2.1.0 include: - -- New module ``brainpy.dyn`` for dynamics building and simulation. It is composed of many - neuron models, synapse models, and others. -- New module ``brainpy.nn`` for neural network building and training. It supports to - define reservoir models, artificial neural networks, ridge regression training, - and back-propagation through time training. -- New module ``brainpy.datasets`` for convenient dataset construction and initialization. -- New module ``brainpy.integrators.dde`` for numerical integration of delay differential equations. -- Add more numpy-like operators in ``brainpy.math`` module. -- Add automatic continuous integration on Linux, Windows, and MacOS platforms. -- Fully update brainpy documentation. -- Fix bugs on ``brainpy.analysis`` and ``brainpy.math.autograd`` - - -Incompatible changes -~~~~~~~~~~~~~~~~~~~~ - -- Remove ``brainpy.math.numpy`` module. -- Remove numba requirements -- Remove matplotlib requirements -- Remove `steps` in ``brainpy.dyn.DynamicalSystem`` -- Remove travis CI - - -New Features -~~~~~~~~~~~~ - -- ``brainpy.ddeint`` for numerical integration of delay differential equations, - the supported methods include: - - Euler - - MidPoint - - Heun2 - - Ralston2 - - RK2 - - RK3 - - Heun3 - - Ralston3 - - SSPRK3 - - RK4 - - Ralston4 - - RK4Rule38 -- set default int/float/complex types - - ``brainpy.math.set_dfloat()`` - - ``brainpy.math.set_dint()`` - - ``brainpy.math.set_dcomplex()`` -- Delay variables - - ``brainpy.math.FixedLenDelay`` - - ``brainpy.math.NeutralDelay`` -- Dedicated operators - - ``brainpy.math.sparse_matmul()`` -- More numpy-like operators -- Neural network building ``brainpy.nn`` -- Dynamics model building and simulation ``brainpy.dyn`` - - -Version 2.0.2 (2022.02.11) -========================== - -There are important updates by `Chaoming Wang `_ -in BrainPy 2.0.2. - -- provide ``pre2post_event_prod`` operator -- support array creation from a list/tuple of JaxArray in ``brainpy.math.asarray`` and ``brainpy.math.array`` -- update ``brainpy.ConstantDelay``, add ``.latest`` and ``.oldest`` attributes -- add ``brainpy.IntegratorRunner`` support for efficient simulation of brainpy integrators -- support auto finding of RandomState when JIT SDE integrators -- fix bugs in SDE ``exponential_euler`` method -- move ``parallel`` running APIs into ``brainpy.simulation`` -- add ``brainpy.math.syn2post_mean``, ``brainpy.math.syn2post_softmax``, - ``brainpy.math.pre2post_mean`` and ``brainpy.math.pre2post_softmax`` operators - - - -Version 2.0.1 (2022.01.31) -========================== - -Today we release BrainPy 2.0.1. This release is composed of over -70 commits since 2.0.0, made by `Chaoming Wang `_, -`Xiaoyu Chen `_, and -`Tianqiu Zhang `_ . - -BrainPy 2.0.0 updates are focused on improving documentation and operators. -Core changes include: - -- Improve ``brainpylib`` operators -- Complete documentation for programming system -- Add more numpy APIs -- Add ``jaxfwd`` in autograd module -- And other changes - - -Version 2.0.0.1 (2022.01.05) -============================ - -- Add progress bar in ``brainpy.StructRunner`` - - -Version 2.0.0 (2021.12.31) -========================== - -Start a new version of BrainPy. - -Highlight -~~~~~~~~~ - -We are excited to announce the release of BrainPy 2.0.0. This release is composed of over -260 commits since 1.1.7, made by `Chaoming Wang `_, -`Xiaoyu Chen `_, and `Tianqiu Zhang `_ . - -BrainPy 2.0.0 updates are focused on improving performance, usability and consistence of BrainPy. -All the computations are migrated into JAX. Model ``building``, ``simulation``, ``training`` -and ``analysis`` are all based on JAX. Highlights of version 2.0.0 include: - -- `brainpylib `_ are provided to dedicated operators for - brain dynamics programming -- Connection APIs in ``brainpy.conn`` module are more efficient. -- Update analysis tools for low-dimensional and high-dimensional systems in ``brainpy.analysis`` module. -- Support more general Exponential Euler methods based on automatic differentiation. -- Improve the usability and consistence of ``brainpy.math`` module. -- Remove JIT compilation based on Numba. -- Separate brain building with brain simulation. - - -Incompatible changes -~~~~~~~~~~~~~~~~~~~~ - -- remove ``brainpy.math.use_backend()`` -- remove ``brainpy.math.numpy`` module -- no longer support ``.run()`` in ``brainpy.DynamicalSystem`` (see New Features) -- remove ``brainpy.analysis.PhasePlane`` (see New Features) -- remove ``brainpy.analysis.Bifurcation`` (see New Features) -- remove ``brainpy.analysis.FastSlowBifurcation`` (see New Features) - - -New Features -~~~~~~~~~~~~ - -- Exponential Euler method based on automatic differentiation - - ``brainpy.ode.ExpEulerAuto`` -- Numerical optimization based low-dimensional analyzers: - - ``brainpy.analysis.PhasePlane1D`` - - ``brainpy.analysis.PhasePlane2D`` - - ``brainpy.analysis.Bifurcation1D`` - - ``brainpy.analysis.Bifurcation2D`` - - ``brainpy.analysis.FastSlow1D`` - - ``brainpy.analysis.FastSlow2D`` -- Numerical optimization based high-dimensional analyzer: - - ``brainpy.analysis.SlowPointFinder`` -- Dedicated operators in ``brainpy.math`` module: - - ``brainpy.math.pre2post_event_sum`` - - ``brainpy.math.pre2post_sum`` - - ``brainpy.math.pre2post_prod`` - - ``brainpy.math.pre2post_max`` - - ``brainpy.math.pre2post_min`` - - ``brainpy.math.pre2syn`` - - ``brainpy.math.syn2post`` - - ``brainpy.math.syn2post_prod`` - - ``brainpy.math.syn2post_max`` - - ``brainpy.math.syn2post_min`` -- Conversion APIs in ``brainpy.math`` module: - - ``brainpy.math.as_device_array()`` - - ``brainpy.math.as_variable()`` - - ``brainpy.math.as_jaxarray()`` -- New autograd APIs in ``brainpy.math`` module: - - ``brainpy.math.vector_grad()`` -- Simulation runners: - - ``brainpy.ReportRunner`` - - ``brainpy.StructRunner`` - - ``brainpy.NumpyRunner`` -- Commonly used models in ``brainpy.models`` module - - ``brainpy.models.LIF`` - - ``brainpy.models.Izhikevich`` - - ``brainpy.models.AdExIF`` - - ``brainpy.models.SpikeTimeInput`` - - ``brainpy.models.PoissonInput`` - - ``brainpy.models.DeltaSynapse`` - - ``brainpy.models.ExpCUBA`` - - ``brainpy.models.ExpCOBA`` - - ``brainpy.models.AMPA`` - - ``brainpy.models.GABAa`` -- Naming cache clean: ``brainpy.clear_name_cache`` -- add safe in-place operations of ``update()`` method and ``.value`` assignment for JaxArray - - -Documentation -~~~~~~~~~~~~~ - -- Complete tutorials for quickstart -- Complete tutorials for dynamics building -- Complete tutorials for dynamics simulation -- Complete tutorials for dynamics training -- Complete tutorials for dynamics analysis -- Complete tutorials for API documentation - - -brainpy 1.1.x -************* - - -If you are using ``brainpy==1.x``, you can find *documentation*, *examples*, and *models* through the following links: - -- **Documentation:** https://brainpy.readthedocs.io/en/brainpy-1.x/ -- **Examples from papers**: https://brainpy-examples.readthedocs.io/en/brainpy-1.x/ -- **Canonical brain models**: https://brainmodels.readthedocs.io/en/brainpy-1.x/ - - -Version 1.1.7 (2021.12.13) -========================== - -- fix bugs on ``numpy_array()`` conversion in `brainpy.math.utils` module - - -Version 1.1.5 (2021.11.17) -========================== - -**API changes:** - -- fix bugs on ndarray import in `brainpy.base.function.py` -- convenient 'get_param' interface `brainpy.simulation.layers` -- add more weight initialization methods - -**Doc changes:** - -- add more examples in README - - -Version 1.1.4 -============= - -**API changes:** - -- add ``.struct_run()`` in DynamicalSystem -- add ``numpy_array()`` conversion in `brainpy.math.utils` module -- add ``Adagrad``, ``Adadelta``, ``RMSProp`` optimizers -- remove `setting` methods in `brainpy.math.jax` module -- remove import jax in `brainpy.__init__.py` and enable jax setting, including - - - ``enable_x64()`` - - ``set_platform()`` - - ``set_host_device_count()`` -- enable ``b=None`` as no bias in `brainpy.simulation.layers` -- set `int_` and `float_` as default 32 bits -- remove ``dtype`` setting in Initializer constructor - -**Doc changes:** - -- add ``optimizer`` in "Math Foundation" -- add ``dynamics training`` docs -- improve others - - -Version 1.1.3 -============= - -- fix bugs of JAX parallel API imports -- fix bugs of `post_slice` structure construction -- update docs - - -Version 1.1.2 -============= - -- add ``pre2syn`` and ``syn2post`` operators -- add `verbose` and `check` option to ``Base.load_states()`` -- fix bugs on JIT DynamicalSystem (numpy backend) - - -Version 1.1.1 -============= - -- fix bugs on symbolic analysis: model trajectory -- change `absolute` access in the variable saving and loading to the `relative` access -- add UnexpectedTracerError hints in JAX transformation functions - - -Version 1.1.0 (2021.11.08) -========================== - -This package releases a new version of BrainPy. - -Highlights of core changes: - -``math`` module -~~~~~~~~~~~~~~~ - -- support numpy backend -- support JAX backend -- support ``jit``, ``vmap`` and ``pmap`` on class objects on JAX backend -- support ``grad``, ``jacobian``, ``hessian`` on class objects on JAX backend -- support ``make_loop``, ``make_while``, and ``make_cond`` on JAX backend -- support ``jit`` (based on numba) on class objects on numpy backend -- unified numpy-like ndarray operation APIs -- numpy-like random sampling APIs -- FFT functions -- gradient descent optimizers -- activation functions -- loss function -- backend settings - - -``base`` module -~~~~~~~~~~~~~~~ - -- ``Base`` for whole Version ecosystem -- ``Function`` to wrap functions -- ``Collector`` and ``TensorCollector`` to collect variables, integrators, nodes and others - - -``integrators`` module -~~~~~~~~~~~~~~~~~~~~~~ - -- class integrators for ODE numerical methods -- class integrators for SDE numerical methods - -``simulation`` module -~~~~~~~~~~~~~~~~~~~~~ - -- support modular and composable programming -- support multi-scale modeling -- support large-scale modeling -- support simulation on GPUs -- fix bugs on ``firing_rate()`` -- remove ``_i`` in ``update()`` function, replace ``_i`` with ``_dt``, - meaning the dynamic system has the canonic equation form - of :math:`dx/dt = f(x, t, dt)` -- reimplement the ``input_step`` and ``monitor_step`` in a more intuitive way -- support to set `dt` in the single object level (i.e., single instance of DynamicSystem) -- common used DNN layers -- weight initializations -- refine synaptic connections - - -brainpy 1.0.x -************* - -Version 1.0.3 (2021.08.18) -========================== - -Fix bugs on - -- firing rate measurement -- stability analysis - - -Version 1.0.2 -============= - -This release continues to improve the user-friendliness. - -Highlights of core changes: - -* Remove support for Numba-CUDA backend -* Super initialization `super(XXX, self).__init__()` can be done at anywhere - (not required to add at the bottom of the `__init__()` function). -* Add the output message of the step function running error. -* More powerful support for Monitoring -* More powerful support for running order scheduling -* Remove `unsqueeze()` and `squeeze()` operations in ``brainpy.ops`` -* Add `reshape()` operation in ``brainpy.ops`` -* Improve docs for numerical solvers -* Improve tests for numerical solvers -* Add keywords checking in ODE numerical solvers -* Add more unified operations in brainpy.ops -* Support "@every" in steps and monitor functions -* Fix ODE solver bugs for class bounded function -* Add build phase in Monitor - - -Version 1.0.1 -============= - -- Fix bugs - - -Version 1.0.0 -============= - -- **NEW VERSION OF BRAINPY** -- Change the coding style into the object-oriented programming -- Systematically improve the documentation - - -brainpy 0.x -*********** - -Version 0.3.5 -============= - -- Add 'timeout' in sympy solver in neuron dynamics analysis -- Reconstruct and generalize phase plane analysis -- Generalize the repeat mode of ``Network`` to different running duration between two runs -- Update benchmarks -- Update detailed documentation - - -Version 0.3.1 -============= - -- Add a more flexible way for NeuState/SynState initialization -- Fix bugs of "is_multi_return" -- Add "hand_overs", "requires" and "satisfies". -- Update documentation -- Auto-transform `range` to `numba.prange` -- Support `_obj_i`, `_pre_i`, `_post_i` for more flexible operation in scalar-based models - - - -Version 0.3.0 -============= - -Computation API -~~~~~~~~~~~~~~~ - -- Rename "brainpy.numpy" to "brainpy.backend" -- Delete "pytorch", "tensorflow" backends -- Add "numba" requirement -- Add GPU support - -Profile setting -~~~~~~~~~~~~~~~ - -- Delete "backend" profile setting, add "jit" - -Core systems -~~~~~~~~~~~~ - -- Delete "autopepe8" requirement -- Delete the format code prefix -- Change keywords "_t_, _dt_, _i_" to "_t, _dt, _i" -- Change the "ST" declaration out of "requires" -- Add "repeat" mode run in Network -- Change "vector-based" to "mode" in NeuType and SynType definition - -Package installation -~~~~~~~~~~~~~~~~~~~~ - -- Remove "pypi" installation, installation now only rely on "conda" - - - -Version 0.2.4 -============= - -API changes -~~~~~~~~~~~ - -- Fix bugs - - -Version 0.2.3 -============= - -API changes -~~~~~~~~~~~ - -- Add "animate_1D" in ``visualization`` module -- Add "PoissonInput", "SpikeTimeInput" and "FreqInput" in ``inputs`` module -- Update phase_portrait_analyzer.py - - -Models and examples -~~~~~~~~~~~~~~~~~~~ - -- Add CANN examples - - -Version 0.2.2 -============= - -API changes -~~~~~~~~~~~ - -- Redesign visualization -- Redesign connectivity -- Update docs - - -Version 0.2.1 -============= - -API changes -~~~~~~~~~~~ - -- Fix bugs in `numba import` -- Fix bugs in `numpy` mode with `scalar` model - - -Version 0.2.0 -============= - -API changes -~~~~~~~~~~~ - -- For computation: ``numpy``, ``numba`` -- For model definition: ``NeuType``, ``SynConn`` -- For model running: ``Network``, ``NeuGroup``, ``SynConn``, ``Runner`` -- For numerical integration: ``integrate``, ``Integrator``, ``DiffEquation`` -- For connectivity: ``One2One``, ``All2All``, ``GridFour``, ``grid_four``, - ``GridEight``, ``grid_eight``, ``GridN``, ``FixedPostNum``, ``FixedPreNum``, - ``FixedProb``, ``GaussianProb``, ``GaussianWeight``, ``DOG`` -- For visualization: ``plot_value``, ``plot_potential``, ``plot_raster``, - ``animation_potential`` -- For measurement: ``cross_correlation``, ``voltage_fluctuation``, - ``raster_plot``, ``firing_rate`` -- For inputs: ``constant_current``, ``spike_current``, ``ramp_current``. - - -Models and examples -~~~~~~~~~~~~~~~~~~~ - -- Neuron models: ``HH model``, ``LIF model``, ``Izhikevich model`` -- Synapse models: ``AMPA``, ``GABA``, ``NMDA``, ``STP``, ``GapJunction`` -- Network models: ``gamma oscillation`` - diff --git a/docs/api.rst b/docs/api.rst index 076ce48c..4e0bc42d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,7 +4,8 @@ API Documentation .. toctree:: :maxdepth: 1 - apis/auto/changelog.rst + apis/auto/brainpy-changelog.md + apis/auto/brainpylib-changelog.md apis/brainpy.rst apis/math.rst apis/dnn.rst diff --git a/docs/conf.py b/docs/conf.py index 19b1ab5b..1ff612cb 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -40,7 +40,8 @@ # sys.exit() changelogs = [ - ('../changelog.rst', 'apis/auto/changelog.rst'), + ('../brainpy-changelog.md', 'apis/auto/brainpy-changelog.md'), + ('../brainpylib-changelog.md', 'apis/auto/brainpylib-changelog.md'), ] for source, dest in changelogs: if os.path.exists(dest): diff --git a/docs/index.rst b/docs/index.rst index 732b27aa..ada4a873 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,21 +17,20 @@ Installation .. code-block:: bash - pip install -U brainpy brainpylib # windows, linux, macos + pip install -U brainpy[cpu] # windows, linux, macos - .. tab-item:: GPU (CUDA-11x) + .. tab-item:: GPU (CUDA) .. code-block:: bash - pip install -U brainpy brainpylib-cu11x # only on linux + # for CUDA 11.0, Linux only + pip install -U brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - .. tab-item:: GPU (CUDA-12x) + # for CUDA 12.0, Linux only + pip install -U brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - .. code-block:: bash - - pip install -U brainpy brainpylib-cu12x # only on linux -For more information about supported accelerators and platforms, and for other installation details, please see `installation `_ section. +For more information, please see `installation `_ section. ---- diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 46ce3822..6f51bfbd 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -10,8 +10,19 @@ Installation Linux, and MacOS. It only relies on Python libraries. -Minimum requirements --------------------- +Minimum requirements (without dependencies) +------------------------------------------- + +To install brainpy with minimum requirements (has installed ``jax`` and ``jaxlib`` before), you can use: + +.. code-block:: bash + + pip install brainpy # for CPU + + + +Minimum requirements (with dependencies) +---------------------------------------- To install brainpy with minimum requirements (only depends on ``jax``), you can use: @@ -21,7 +32,8 @@ To install brainpy with minimum requirements (only depends on ``jax``), you can # or - pip install brainpy[cuda_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for GPU (Linux only) + pip install brainpy[cuda11_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 + pip install brainpy[cuda12_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 @@ -63,18 +75,10 @@ To install the ``brainpylib`` package on CPU devices, you can run pip install brainpylib -To install the ``brainpylib`` package on CUDA 11, you can run +To install the ``brainpylib`` package on CUDA (Linux only), you can run .. code-block:: bash - pip install brainpylib-cu11x - - -To install the ``brainpylib`` package on CUDA 12, you can run - - -.. code-block:: bash - - pip install brainpylib-cu12x + pip install brainpylib diff --git a/examples/operator_customization/event_ell.py b/examples/operator_customization/event_ell.py new file mode 100644 index 00000000..0c5e7f8a --- /dev/null +++ b/examples/operator_customization/event_ell.py @@ -0,0 +1,40 @@ +import jax +import jax.numpy as jnp +import taichi as ti + +import brainpy.math as bm + + +@ti.kernel +def event_ell_cpu(indices: ti.types.ndarray(ndim=2), + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + weight_val = weight[0] + num_rows, num_cols = indices.shape + ti.loop_config(serialize=True) + for i in range(num_rows): + if vector[i]: + for j in range(num_cols): + out[indices[i, j]] += weight_val + + +prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) + + +def try_taichi_op_register(): + s = 1000 + indices = bm.random.randint(0, s, (s, 100)) + vector = bm.random.rand(s) < 0.1 + weight = bm.array([1.0]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + + # print(out) + bm.clear_buffer_memory() + + +# bm.clear_taichi_aot_caches() +try_taichi_op_register() diff --git a/setup.py b/setup.py index 885bbf57..55f948e4 100644 --- a/setup.py +++ b/setup.py @@ -69,16 +69,15 @@ ], extras_require={ 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], - 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], 'tpu': ['jaxlib[tpu]', 'numba',], 'cpu_mini': ['jaxlib>=0.4.13'], - 'cuda_mini': ['jaxlib[cuda12_pip]'], + 'cuda11_mini': ['jaxlib[cuda11_pip]'], + 'cuda12_mini': ['jaxlib[cuda12_pip]'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' - 'dynamical systems, ' - 'differential equations, ' 'brain modeling, ' 'brain dynamics modeling, ' 'brain dynamics programming'), From 165b9eef22190ac169e36d0304408097cdd0dff0 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 3 Mar 2024 14:23:15 +0800 Subject: [PATCH 11/21] Find back updates (#646) * roll back previous updates * update JIT transform * Update * Update delayvars.py * remove using internal API of jax for registering pytree objects --------- Co-authored-by: He Sichao <1310722434@qq.com> --- brainpy/_src/dnn/conv.py | 11 +- brainpy/_src/dnn/tests/test_activation.py | 2 +- brainpy/_src/dnn/tests/test_conv_layers.py | 11 +- brainpy/_src/dnn/tests/test_function.py | 6 +- brainpy/_src/dnn/tests/test_normalization.py | 3 +- brainpy/_src/dnn/tests/test_pooling_layers.py | 2 +- brainpy/_src/math/delayvars.py | 5 +- .../_src/math/object_transform/autograd.py | 45 +- brainpy/_src/math/object_transform/base.py | 6 +- .../_src/math/object_transform/controls.py | 136 +++--- brainpy/_src/math/object_transform/jit.py | 129 +++-- brainpy/_src/math/object_transform/naming.py | 10 +- .../_src/math/object_transform/parallels.py | 460 ------------------ brainpy/_src/math/object_transform/tools.py | 75 ++- .../_src/math/object_transform/variables.py | 45 +- brainpy/_src/tools/functions.py | 192 ++++++++ brainpy/_src/tools/tests/test_functions.py | 24 + brainpy/math/compat_pytorch.py | 2 +- brainpy/math/oo_transform.py | 4 +- brainpy/tools.py | 4 + docs/advanced_tutorials.rst | 51 +- docs/apis/brainpy.math.oo_transform.rst | 1 + docs/toolboxes.rst | 38 +- docs/tutorials.rst | 77 ++- 24 files changed, 610 insertions(+), 729 deletions(-) delete mode 100644 brainpy/_src/math/object_transform/parallels.py create mode 100644 brainpy/_src/tools/functions.py create mode 100644 brainpy/_src/tools/tests/test_functions.py diff --git a/brainpy/_src/dnn/conv.py b/brainpy/_src/dnn/conv.py index deead1f3..e4b6e25d 100644 --- a/brainpy/_src/dnn/conv.py +++ b/brainpy/_src/dnn/conv.py @@ -160,7 +160,7 @@ def update(self, x): nonbatching = False if x.ndim == self.num_spatial_dims + 1: nonbatching = True - x = x.unsqueeze(0) + x = bm.unsqueeze(x, 0) w = self.w.value if self.mask is not None: try: @@ -190,6 +190,9 @@ def __repr__(self): class Conv1d(_GeneralConv): """One-dimensional convolution. + The input should a 2d array with the shape of ``[H, C]``, or + a 3d array with the shape of ``[B, H, C]``, where ``H`` is the feature size. + Parameters ---------- in_channels: int @@ -282,6 +285,9 @@ def _check_input_dim(self, x): class Conv2d(_GeneralConv): """Two-dimensional convolution. + The input should a 3d array with the shape of ``[H, W, C]``, or + a 4d array with the shape of ``[B, H, W, C]``. + Parameters ---------- in_channels: int @@ -375,6 +381,9 @@ def _check_input_dim(self, x): class Conv3d(_GeneralConv): """Three-dimensional convolution. + The input should a 3d array with the shape of ``[H, W, D, C]``, or + a 4d array with the shape of ``[B, H, W, D, C]``. + Parameters ---------- in_channels: int diff --git a/brainpy/_src/dnn/tests/test_activation.py b/brainpy/_src/dnn/tests/test_activation.py index ba2a49ef..7a0fa57a 100644 --- a/brainpy/_src/dnn/tests/test_activation.py +++ b/brainpy/_src/dnn/tests/test_activation.py @@ -1,5 +1,5 @@ -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/dnn/tests/test_conv_layers.py b/brainpy/_src/dnn/tests/test_conv_layers.py index 3c9fdfa8..05f52362 100644 --- a/brainpy/_src/dnn/tests/test_conv_layers.py +++ b/brainpy/_src/dnn/tests/test_conv_layers.py @@ -1,17 +1,15 @@ # -*- coding: utf-8 -*- -from unittest import TestCase -from absl.testing import absltest import jax.numpy as jnp -import brainpy.math as bm +from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp import brainpy.math as bm class TestConv(parameterized.TestCase): def test_Conv2D_img(self): - bm.random.seed() img = jnp.zeros((2, 200, 198, 4)) for k in range(4): x = 30 + 60 * k @@ -24,6 +22,7 @@ def test_Conv2D_img(self): strides=(2, 1), padding='VALID', groups=4) out = net(img) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 99, 196, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(img)[0, :, :, 0]) @@ -31,7 +30,6 @@ def test_Conv2D_img(self): bm.clear_buffer_memory() def test_conv1D(self): - bm.random.seed() with bp.math.training_environment(): model = bp.layers.Conv1d(in_channels=3, out_channels=32, kernel_size=(3,)) @@ -39,6 +37,7 @@ def test_conv1D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :]) @@ -54,6 +53,7 @@ def test_conv2D(self): out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 32)) # print("First output channel:") # plt.figure(figsize=(10, 10)) # plt.imshow(np.array(out)[0, :, :, 31]) @@ -67,6 +67,7 @@ def test_conv3D(self): input = bp.math.ones((2, 5, 5, 5, 3)) out = model(input) print("out shape: ", out.shape) + self.assertEqual(out.shape, (2, 5, 5, 5, 32)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_function.py b/brainpy/_src/dnn/tests/test_function.py index 269fec44..9ad15938 100644 --- a/brainpy/_src/dnn/tests/test_function.py +++ b/brainpy/_src/dnn/tests/test_function.py @@ -1,12 +1,10 @@ # -*- coding: utf-8 -*- -from unittest import TestCase - -import jax.numpy as jnp -import brainpy.math as bm from absl.testing import absltest from absl.testing import parameterized + import brainpy as bp +import brainpy.math as bm class TestFunction(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_normalization.py b/brainpy/_src/dnn/tests/test_normalization.py index fdc5b34e..e76b3616 100644 --- a/brainpy/_src/dnn/tests/test_normalization.py +++ b/brainpy/_src/dnn/tests/test_normalization.py @@ -1,7 +1,8 @@ -import brainpy.math as bm from absl.testing import parameterized from absl.testing import absltest + import brainpy as bp +import brainpy.math as bm class Test_Normalization(parameterized.TestCase): diff --git a/brainpy/_src/dnn/tests/test_pooling_layers.py b/brainpy/_src/dnn/tests/test_pooling_layers.py index 34f8f5cd..5748edd8 100644 --- a/brainpy/_src/dnn/tests/test_pooling_layers.py +++ b/brainpy/_src/dnn/tests/test_pooling_layers.py @@ -3,8 +3,8 @@ import jax import jax.numpy as jnp import numpy as np -from absl.testing import parameterized from absl.testing import absltest +from absl.testing import parameterized import brainpy as bp import brainpy.math as bm diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index eb8e27c8..676e4286 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import vstack, broadcast_to +from .compat_numpy import broadcast_to, expand_dims, concatenate from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,6 +392,7 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: + self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) else: self.data[:] = value diff --git a/brainpy/_src/math/object_transform/autograd.py b/brainpy/_src/math/object_transform/autograd.py index f5e09167..ad8a5ccf 100644 --- a/brainpy/_src/math/object_transform/autograd.py +++ b/brainpy/_src/math/object_transform/autograd.py @@ -28,10 +28,8 @@ get_stack_cache, cache_stack) from .base import (BrainPyObject, ObjectTransform) -from .variables import (Variable, - VariableStack, - current_transform_number, - new_transform) +from .variables import (Variable, VariableStack) +from .tools import eval_shape __all__ = [ 'grad', # gradient of scalar function @@ -203,36 +201,21 @@ def __call__(self, *args, **kwargs): elif not self._eval_dyn_vars: # evaluate dynamical variables stack = get_stack_cache(self.target) if stack is None: - with new_transform(self): - with VariableStack() as stack: - if current_transform_number() > 1: - rets = self._transform( - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) - else: - rets = jax.eval_shape( - self._transform, - [v.value for v in self._grad_vars], # variables for gradients - {}, # dynamical variables - *args, - **kwargs - ) + with VariableStack() as stack: + rets = eval_shape(self._transform, + [v.value for v in self._grad_vars], # variables for gradients + {}, # dynamical variables + *args, + **kwargs) cache_stack(self.target, stack) - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + self._dyn_vars = stack + self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) + self._eval_dyn_vars = True - # if not the outermost transformation - if current_transform_number(): - return self._return(rets) - else: - self._dyn_vars = stack - self._dyn_vars.remove_by_id(*[id(v) for v in self._grad_vars]) - self._eval_dyn_vars = True + # if not the outermost transformation + if not stack.is_first_stack(): + return self._return(rets) rets = self._transform( [v.value for v in self._grad_vars], # variables for gradients diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index b21ed2af..936f6238 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -12,7 +12,6 @@ import jax import numpy as np -from jax._src.tree_util import _registry from jax.tree_util import register_pytree_node_class from brainpy._src.math.modes import Mode @@ -27,6 +26,8 @@ variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) +registered = set() + __all__ = [ 'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform', @@ -91,8 +92,9 @@ def __init__(self, name=None): super().__init__() if defaults.bp_object_as_pytree: - if self.__class__ not in _registry: + if self.__class__ not in registered: register_pytree_node_class(self.__class__) + registered.add(self.__class__) # check whether the object has a unique name. self._name = None diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 032a0fab..3edeb08e 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -21,17 +21,12 @@ cache_stack ) from .tools import ( - evaluate_dyn_vars, + eval_shape, dynvar_deprecation, node_deprecation, abstract ) -from .variables import ( - Variable, - VariableStack, - new_transform, - current_transform_number, -) +from .variables import (Variable, VariableStack) __all__ = [ 'make_loop', @@ -542,15 +537,13 @@ def cond( node_deprecation(child_objs) dyn_vars = get_stack_cache((true_fun, false_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('cond'): - dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((true_fun, false_fun), dyn_vars) - if current_transform_number() > 0: - return rets + if not jax.config.jax_disable_jit and dyn_vars is None: + with VariableStack() as dyn_vars: + rets = eval_shape(true_fun, *operands, with_stack=True)[1] + _ = eval_shape(false_fun, *operands, with_stack=True) + cache_stack((true_fun, false_fun), dyn_vars) + if not dyn_vars.is_first_stack(): + return rets dyn_vars = VariableStack() if dyn_vars is None else dyn_vars dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands) for k in dyn_values.keys(): @@ -681,20 +674,16 @@ def ifelse( else: dyn_vars = get_stack_cache(tuple(branches)) if dyn_vars is None: - with new_transform('ifelse'): - with VariableStack() as dyn_vars: - if current_transform_number() > 1: - rets = [branch(*operands) for branch in branches] - else: - rets = [jax.eval_shape(branch, *operands) for branch in branches] - trees = [jax.tree_util.tree_structure(ret) for ret in rets] - if not _all_equal(trees): - msg = 'All returns in branches should have the same tree structure. But we got:\n' - for tree in trees: - msg += f'- {tree}\n' - raise TypeError(msg) + with VariableStack() as dyn_vars: + rets = [eval_shape(fun, *operands, with_stack=True)[1] for fun in branches] + trees = [jax.tree_util.tree_structure(ret) for ret in rets] + if not _all_equal(trees): + msg = 'All returns in branches should have the same tree structure. But we got:\n' + for tree in trees: + msg += f'- {tree}\n' + raise TypeError(msg) cache_stack(tuple(branches), dyn_vars) - if current_transform_number(): + if not dyn_vars.is_first_stack(): return rets[0] branches = [_cond_transform_fun(fun, dyn_vars) for fun in branches] @@ -880,28 +869,23 @@ def for_loop( if jit is None: # jax disable jit jit = not jax.config.jax_disable_jit - dyn_vars = get_stack_cache((body_fun, unroll_kwargs)) + stack = get_stack_cache((body_fun, unroll_kwargs)) if jit: - if dyn_vars is None: + if stack is None: + transform = _get_for_loop_transform(body_fun, VariableStack(), bar, progress_bar, + remat, reverse, unroll, unroll_kwargs) # TODO: better cache mechanism? - with new_transform('for_loop'): - with VariableStack() as dyn_vars: - transform = _get_for_loop_transform(body_fun, VariableStack(), bar, - progress_bar, remat, reverse, unroll, - unroll_kwargs) - if current_transform_number() > 1: - rets = transform(operands) - else: - rets = jax.eval_shape(transform, operands) - cache_stack((body_fun, unroll_kwargs), dyn_vars) # cache - if current_transform_number(): + with VariableStack() as stack: + rets = eval_shape(transform, operands) + cache_stack((body_fun, unroll_kwargs), stack) # cache + if not stack.is_first_stack(): return rets[1] del rets else: - dyn_vars = VariableStack() + stack = VariableStack() # TODO: cache mechanism? - transform = _get_for_loop_transform(body_fun, dyn_vars, bar, + transform = _get_for_loop_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll, unroll_kwargs) if jit: @@ -909,11 +893,11 @@ def for_loop( else: with jax.disable_jit(): dyn_vals, out_vals = transform(operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() - del dyn_vals, dyn_vars + del dyn_vals, stack return out_vals @@ -1011,26 +995,21 @@ def scan( num_total = min([op.shape[0] for op in jax.tree_util.tree_flatten(operands)[0]]) bar = tqdm(total=num_total) - dyn_vars = get_stack_cache(body_fun) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('scan'): - with VariableStack() as dyn_vars: - transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) - if current_transform_number() > 1: - rets = transform(init, operands) - else: - rets = jax.eval_shape(transform, init, operands) - cache_stack(body_fun, dyn_vars) # cache - if current_transform_number(): - return rets[0][1], rets[1] - del rets - - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - transform = _get_scan_transform(body_fun, dyn_vars, bar, progress_bar, remat, reverse, unroll) + stack = get_stack_cache(body_fun) + if not jax.config.jax_disable_jit and stack is None: + transform = _get_scan_transform(body_fun, VariableStack(), bar, progress_bar, remat, reverse, unroll) + with VariableStack() as stack: + rets = eval_shape(transform, init, operands) + cache_stack(body_fun, stack) # cache + if not stack.is_first_stack(): + return rets[0][1], rets[1] + del rets + + stack = VariableStack() if stack is None else stack + transform = _get_scan_transform(body_fun, stack, bar, progress_bar, remat, reverse, unroll) (dyn_vals, carry), out_vals = transform(init, operands) - for key in dyn_vars.keys(): - dyn_vars[key]._value = dyn_vals[key] + for key in stack.keys(): + stack[key]._value = dyn_vals[key] if progress_bar: bar.close() return carry, out_vals @@ -1129,7 +1108,6 @@ def while_loop( No longer need to provide ``child_objs``. This function is capable of automatically collecting the children objects used in the target ``func``. - """ dynvar_deprecation(dyn_vars) node_deprecation(child_objs) @@ -1137,18 +1115,16 @@ def while_loop( if not isinstance(operands, (list, tuple)): operands = (operands,) - dyn_vars = get_stack_cache((body_fun, cond_fun)) - if not jax.config.jax_disable_jit: - if dyn_vars is None: - with new_transform('while_loop'): - dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1) - dyn_vars = dyn_vars1 + dyn_vars2 - cache_stack((body_fun, cond_fun), dyn_vars) - if current_transform_number(): - return rets - dyn_vars = VariableStack() if dyn_vars is None else dyn_vars - dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands) - for k, v in dyn_vars.items(): + stack = get_stack_cache((body_fun, cond_fun)) + if not jax.config.jax_disable_jit and stack is None: + with VariableStack() as stack: + _ = eval_shape(cond_fun, *operands, with_stack=True) + rets = eval_shape(body_fun, *operands, with_stack=True)[1] + cache_stack((body_fun, cond_fun), stack) + if not stack.is_first_stack(): + return rets + stack = VariableStack() if stack is None else stack + dyn_values, out = _get_while_transform(cond_fun, body_fun, stack)(operands) + for k, v in stack.items(): v._value = dyn_values[k] return out diff --git a/brainpy/_src/math/object_transform/jit.py b/brainpy/_src/math/object_transform/jit.py index 7bb36f4e..551a0949 100644 --- a/brainpy/_src/math/object_transform/jit.py +++ b/brainpy/_src/math/object_transform/jit.py @@ -11,23 +11,15 @@ from typing import Callable, Union, Optional, Sequence, Dict, Any, Iterable import jax -from jax.sharding import Sharding from brainpy import tools, check -from .tools import (dynvar_deprecation, - node_deprecation, - evaluate_dyn_vars_with_cache, - evaluate_dyn_vars, - _partial_fun) from .base import BrainPyObject, ObjectTransform from .naming import get_stack_cache, cache_stack +from .tools import (dynvar_deprecation, + node_deprecation, + eval_shape) +from .variables import (Variable, VariableStack) from ..ndarray import Array -from .variables import (Variable, - VariableStack, - outermost_transform, - transform_stack, - current_transform_number, - new_transform) RandomState = None @@ -96,6 +88,17 @@ def _seq_of_str(static_argnames): return static_argnames +def _jit_call_take_care_of_rngs(transform, stack, *args, **kwargs): + # call the transformed function + rng_keys = stack.call_on_subset(_is_rng, _rng_split_key) + changes, out = transform(stack.dict_data(), *args, **kwargs) + for key, v in changes.items(): + stack[key]._value = v + for key, v in rng_keys.items(): + stack[key]._value = v + return out + + class JITTransform(ObjectTransform): """Object-oriented JIT transformation in BrainPy.""" @@ -142,25 +145,21 @@ def __init__( # OO transformation parameters self._transform = None self._dyn_vars = None - - def _transform_function(self, variable_data: Dict, *args, **kwargs): - for key, v in self._dyn_vars.items(): - v._value = variable_data[key] - out = self.fun(*args, **kwargs) - changes = self._dyn_vars.dict_data_of_subset(_is_not_rng) - return changes, out + # + # def _transform_function(self, variable_data: Dict, *args, **kwargs): + # for key, v in self._dyn_vars.items(): + # v._value = variable_data[key] + # out = self.fun(*args, **kwargs) + # changes = self._dyn_vars.dict_data_of_subset(_is_not_rng) + # return changes, out def _get_transform(self, *args, **kwargs): - with new_transform(self): - self._dyn_vars, rets = evaluate_dyn_vars( - self.fun, - *args, - static_argnums=self._static_argnums, - static_argnames=self._static_argnames, - use_eval_shape=current_transform_number() <= 1, - **kwargs - ) - + with VariableStack() as self._dyn_vars: + rets = eval_shape(self.fun, + *args, + **kwargs, + static_argnums=self._static_argnums, + static_argnames=self._static_argnames) # in_shardings if self._in_shardings is None: in_shardings = None @@ -186,18 +185,18 @@ def _get_transform(self, *args, **kwargs): _dyn_vars_sharing = get_shardings(self._dyn_vars.subset_by_not_instance(RandomState)) out_shardings = (_dyn_vars_sharing,) + out_shardings - # jit - self._transform = jax.jit( - self._transform_function, - static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), - static_argnames=self._static_argnames, - donate_argnums=self._donate_argnums, - inline=self._inline, - keep_unused=self._keep_unused, - abstracted_axes=self._abstracted_axes, - in_shardings=in_shardings, - out_shardings=out_shardings, - ) + # jit + self._transform = jax.jit( + _make_transform(self.fun, self._dyn_vars), + static_argnums=jax.tree_util.tree_map(lambda a: a + 1, self._static_argnums), + static_argnames=self._static_argnames, + donate_argnums=self._donate_argnums, + inline=self._inline, + keep_unused=self._keep_unused, + abstracted_axes=self._abstracted_axes, + in_shardings=in_shardings, + out_shardings=out_shardings, + ) return rets def __call__(self, *args, **kwargs): @@ -207,17 +206,11 @@ def __call__(self, *args, **kwargs): if self._transform is None: # initialize the transformation rets = self._get_transform(*args, **kwargs) # if not the outermost transformation - if current_transform_number(): + if not self._dyn_vars.is_first_stack(): return rets # call the transformed function - rng_keys = self._dyn_vars.call_on_subset(_is_rng, _rng_split_key) - changes, out = self._transform(self._dyn_vars.dict_data(), *args, **kwargs) - for key, v in changes.items(): - self._dyn_vars[key]._value = v - for key, v in rng_keys.items(): - self._dyn_vars[key]._value = v - return out + return _jit_call_take_care_of_rngs(self._transform, self._dyn_vars, *args, **kwargs) def __repr__(self): name = self.__class__.__name__ @@ -314,7 +307,7 @@ def jit( Examples -------- - You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. + You can JIT any object in which all dynamical variables are defined as :py:class:`~.Variable`. >>> import brainpy as bp >>> class Hello(bp.BrainPyObject): @@ -401,12 +394,12 @@ def cls_jit( **kwargs ) -> Callable: """Just-in-time compile a function and then the jitted function as the bound method for a class. - + Examples -------- - + This transformation can be put on any class function. For example, - + >>> import brainpy as bp >>> import brainpy.math as bm >>> @@ -415,7 +408,7 @@ def cls_jit( >>> super(SomeProgram, self).__init__() >>> self.a = bm.zeros(2) >>> self.b = bm.Variable(bm.ones(2)) - >>> + >>> >>> @bm.cls_jit(inline=True) >>> def __call__(self, *args, **kwargs): >>> a = bm.random.uniform(size=2) @@ -424,7 +417,7 @@ def cls_jit( >>> >>> program = SomeProgram() >>> program() - + Parameters ---------- {jit_pars} @@ -477,15 +470,8 @@ def call_fun(self, *args, **kwargs): cache = get_stack_cache(hash_v) # TODO: better cache mechanism if cache is None: fun2 = partial(fun, self) - - with jax.ensure_compile_time_eval(): - if len(static_argnums) or len(static_argnames): - fun3, args_, kwargs_ = _partial_fun(fun2, args, kwargs, static_argnums, static_argnames) - else: - args_, kwargs_, fun3 = args, kwargs, fun2 - with VariableStack() as stack: - _ = jax.eval_shape(fun3, *args_, **kwargs_) - del args_, kwargs_ + with VariableStack() as stack: + out = eval_shape(fun2, *args, **kwargs, static_argnums=static_argnums, static_argnames=static_argnames) _transform = jax.jit( _make_transform(fun2, stack), static_argnums=jax.tree_util.tree_map(lambda a: a + 1, static_argnums), @@ -497,25 +483,22 @@ def call_fun(self, *args, **kwargs): **jit_kwargs ) cache_stack(hash_v, (stack, _transform)) # cache "variable stack" and "transform function" - + if not stack.is_first_stack(): + return out else: stack, _transform = cache - del cache - out, changes = _transform(stack.dict_data(), *args, **kwargs) - for key, v in stack.items(): - v._value = changes[key] - return out + return _jit_call_take_care_of_rngs(_transform, stack, *args, **kwargs) return call_fun def _make_transform(fun, stack): @wraps(fun) - def _transform_function(variable_data: dict, *args, **kwargs): + def _transform_function(variable_data: Dict, *args, **kwargs): for key, v in stack.items(): v._value = variable_data[key] out = fun(*args, **kwargs) - changes = stack.dict_data() - return out, changes + changes = stack.dict_data_of_subset(_is_not_rng) + return changes, out return _transform_function diff --git a/brainpy/_src/math/object_transform/naming.py b/brainpy/_src/math/object_transform/naming.py index 6326929c..1181e003 100644 --- a/brainpy/_src/math/object_transform/naming.py +++ b/brainpy/_src/math/object_transform/naming.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -import gc + import warnings from brainpy import errors @@ -11,7 +11,6 @@ _name2id = dict() _typed_names = {} -_fun2stack = dict() def check_name_uniqueness(name, obj): @@ -42,7 +41,7 @@ def get_unique_name(type_: str): return name -def clear_name_cache(ignore_warn=False): +def clear_name_cache(ignore_warn=True): """Clear the cached names.""" _name2id.clear() _typed_names.clear() @@ -50,14 +49,17 @@ def clear_name_cache(ignore_warn=False): warnings.warn(f'All named models and their ids are cleared.', UserWarning) +_fun2stack = dict() + + def cache_stack(func, stack): _fun2stack[func] = stack def clear_stack_cache(): + """Clear the cached stack.""" for k in tuple(_fun2stack.keys()): del _fun2stack[k] - gc.collect() def get_stack_cache(func): diff --git a/brainpy/_src/math/object_transform/parallels.py b/brainpy/_src/math/object_transform/parallels.py deleted file mode 100644 index 1eddce04..00000000 --- a/brainpy/_src/math/object_transform/parallels.py +++ /dev/null @@ -1,460 +0,0 @@ -# -*- coding: utf-8 -*- - -""" -The parallel compilation tools for JAX backend. - -1. Vectorize compilation is implemented by the 'vmap()' function -2. Parallel compilation is implemented by the 'pmap()' function - -""" - - -import functools - -import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters.partial_eval import DynamicJaxprTracer -from jax.interpreters.partial_eval import JaxprTracer -from jax.interpreters.pxla import ShardedDeviceArray - -try: - from jax.errors import UnexpectedTracerError -except ImportError: - from jax.core import UnexpectedTracerError - -from brainpy import errors -from brainpy._src.math.random import RandomState -from brainpy._src.math.ndarray import Array -from brainpy.tools import change_func_name -from .base import BrainPyObject, ArrayCollector - -__all__ = [ - 'vmap', - 'pmap', -] - - -def _make_vmap(func, nonbatched_vars, batched_vars, in_axes, out_axes, - batch_idx, axis_name, f_name=None): - @functools.partial(jax.vmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name) - def vmapped_func(nonbatched_data, batched_data, *args, **kwargs): - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - out = func(*args, **kwargs) - nonbatched_changes = nonbatched_vars.dict() - batched_changes = batched_vars.dict() - return nonbatched_changes, batched_changes, out - - def call(*args, **kwargs): - n = args[batch_idx[0]].shape[batch_idx[1]] - nonbatched_data = nonbatched_vars.dict() - batched_data = {key: val.split_keys(n) for key, val in batched_vars.items()} - try: - out, dyn_changes, rand_changes = vmapped_func(nonbatched_data, batched_data, *args, **kwargs) - except UnexpectedTracerError as e: - nonbatched_vars.assign(nonbatched_data) - batched_vars.assign(batched_data) - raise errors.JaxTracerError() from e - # for key, v in dyn_changes.items(): - # dyn_vars[key] = reduce_func(v) - # for key, v in rand_changes.items(): - # rand_vars[key] = reduce_func(v) - return out - - return change_func_name(name=f_name, f=call) if f_name else call - - -def vmap(func, dyn_vars=None, batched_vars=None, - in_axes=0, out_axes=0, axis_name=None, - reduce_func=None, auto_infer=False): - """Vectorization compilation for class objects. - - Vectorized compile a function or a module to run in parallel on a single device. - - Examples - -------- - - Parameters - ---------- - func : BrainPyObject, function, callable - The function or the module to compile. - dyn_vars : dict, sequence - batched_vars : dict - in_axes : optional, int, sequence of int - Specify which input array axes to map over. If each positional argument to - ``obj_or_func`` is an array, then ``in_axes`` can be an integer, a None, - or a tuple of integers and Nones with length equal to the number of - positional arguments to ``obj_or_func``. An integer or ``None`` - indicates which array axis to map over for all arguments (with ``None`` - indicating not to map any axis), and a tuple indicates which axis to map - for each corresponding positional argument. Axis integers must be in the - range ``[-ndim, ndim)`` for each array, where ``ndim`` is the number of - dimensions (axes) of the corresponding input array. - - If the positional arguments to ``obj_or_func`` are container types, the - corresponding element of ``in_axes`` can itself be a matching container, - so that distinct array axes can be mapped for different container - elements. ``in_axes`` must be a container tree prefix of the positional - argument tuple passed to ``obj_or_func``. - - At least one positional argument must have ``in_axes`` not None. The sizes - of the mapped input axes for all mapped positional arguments must all be - equal. - - Arguments passed as keywords are always mapped over their leading axis - (i.e. axis index 0). - out_axes : optional, int, tuple/list/dict - Indicate where the mapped axis should appear in the output. All outputs - with a mapped axis must have a non-None ``out_axes`` specification. Axis - integers must be in the range ``[-ndim, ndim)`` for each output array, - where ``ndim`` is the number of dimensions (axes) of the array returned - by the :func:`vmap`-ed function, which is one more than the number of - dimensions (axes) of the corresponding array returned by ``obj_or_func``. - axis_name : optional - - Returns - ------- - obj_or_func : Any - Batched/vectorized version of ``obj_or_func`` with arguments that correspond to - those of ``obj_or_func``, but with extra array axes at positions indicated by - ``in_axes``, and a return value that corresponds to that of ``obj_or_func``, but - with extra array axes at positions indicated by ``out_axes``. - - """ - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # dyn_vars = (dyn_vars or func.vars().unique()) - # dyn_vars, rand_vars = ArrayCollector(), ArrayCollector() - # for key, val in dyn_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # in axes - # if in_axes is None: - # in_axes = {key: (None, 0) for key in func.steps.keys()} - # elif isinstance(in_axes, int): - # in_axes = {key: (None, 0, in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, (tuple, list)): - # in_axes = {key: (None, 0) + tuple(in_axes) for key in func.steps.keys()} - # elif isinstance(in_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in in_axes: - # in_axes = {key: (None, 0, in_axes) for key in keys} - # else: - # in_axes = {key: (None, 0) + tuple(in_axes[key]) for key in keys} - # assert isinstance(in_axes, dict) - # - # # batch size index - # batch_idx = {} - # for key, axes in in_axes.items(): - # for i, axis in enumerate(axes[2:]): - # if axis is not None: - # batch_idx[key] = (i, axis) - # break - # else: - # raise ValueError(f'Found no batch axis: {axes}.') - # - # # out axes - # if out_axes is None: - # out_axes = {key: 0 for key in func.steps.keys()} - # elif isinstance(out_axes, int): - # out_axes = {key: out_axes for key in func.steps.keys()} - # elif isinstance(out_axes, (tuple, list)): - # out_axes = {key: tuple(out_axes) + (0, 0) for key in func.steps.keys()} - # elif isinstance(out_axes, dict): - # keys = list(func.steps.keys()) - # if keys[0] not in out_axes: - # out_axes = {key: (out_axes, 0, 0) for key in keys} - # else: - # out_axes = {key: tuple(out_axes[key]) + (0, 0) for key in keys} - # assert isinstance(out_axes, dict) - # - # # reduce_func - # if reduce_func is None: - # reduce_func = lambda x: x.mean(axis=0) - # - # # vectorized map functions - # for key in func.steps.keys(): - # func.steps[key] = _make_vmap(func=func.steps[key], - # dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # in_axes=in_axes[key], - # out_axes=out_axes[key], - # axis_name=axis_name, - # batch_idx=batch_idx[key], - # reduce_func=reduce_func, - # f_name=key) - # - # return func - - if callable(func): - if auto_infer: - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.vmap(func, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name) - - else: - if isinstance(dyn_vars, Array): - dyn_vars = [dyn_vars] - if isinstance(dyn_vars, (tuple, list)): - dyn_vars = {f'_vmap_v{i}': v for i, v in enumerate(dyn_vars)} - assert isinstance(dyn_vars, dict) - - # dynamical variables - _dyn_vars, _rand_vars = ArrayCollector(), ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - _rand_vars[key] = val - else: - _dyn_vars[key] = val - - # in axes - if in_axes is None: - in_axes = (None, 0) - elif isinstance(in_axes, (int, dict)): - in_axes = (None, 0, in_axes) - elif isinstance(in_axes, (tuple, list)): - in_axes = (None, 0) + tuple(in_axes) - assert isinstance(in_axes, (tuple, list)) - - # batch size index - batch_idx = {} - for key, axes in batch_idx.items(): - for i, axis in enumerate(axes[2:]): - if axis is not None: - batch_idx[key] = (i, axis) - break - else: - raise ValueError(f'Found no batch axis: {axes}.') - - # out axes - if out_axes is None: - out_axes = 0 - elif isinstance(out_axes, (int, dict)): - out_axes = (out_axes, 0, 0) - elif isinstance(out_axes, (tuple, list)): - out_axes = tuple(out_axes) + (0, 0) - assert isinstance(out_axes, (list, tuple)) - - # reduce_func - if reduce_func is None: - reduce_func = lambda x: x.mean(axis=0) - - # jit function - return _make_vmap(func=func, - nonbatched_vars=_dyn_vars, - batched_vars=_rand_vars, - in_axes=in_axes, - out_axes=out_axes, - axis_name=axis_name, - batch_idx=batch_idx) - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable ' - f'function, but we got {type(func)}.') - - -def _device_reshape(x): - """Reshape an input array in order to broadcast to multiple devices.""" - num_device = jax.local_device_count() - - if not hasattr(x, 'ndim'): - raise errors.BrainPyError(f'Expected Array, got {type(x)}. If you are trying to pass a scalar to ' - f'parallel, first convert it to a Array, for example np.float(0.5)') - if x.ndim == 0: - return np.broadcast_to(x, [num_device]) - if x.shape[0] % num_device != 0: - raise errors.BrainPyError(f'Must be able to equally divide batch {x.shape} among ' - f'{num_device} devices, but does not go equally.') - return x.reshape((num_device, x.shape[0] // num_device) + x.shape[1:]) - - -def _make_pmap(func, dyn_vars, rand_vars, reduce_func, axis_name=None, in_axes=0, - out_axes=0, static_broadcasted_argnums=(), devices=None, backend=None, - axis_size=None, donate_argnums=(), global_arg_shapes=None, f_name=None): - @functools.partial(jax.pmap, in_axes=in_axes, out_axes=out_axes, axis_name=axis_name, - static_broadcasted_argnums=static_broadcasted_argnums, devices=devices, - backend=backend, axis_size=axis_size, donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - def pmapped_func(dyn_data, rand_data, *args, **kwargs): - dyn_vars.assign(dyn_data) - rand_vars.assign(rand_data) - out = func(*args, **kwargs) - dyn_changes = dyn_vars.dict() - rand_changes = rand_vars.dict() - return out, dyn_changes, rand_changes - - def call(*args): - un_replicated = [k for k, v in dyn_vars.items() - if not isinstance(v.value, (ShardedDeviceArray, JaxprTracer, DynamicJaxprTracer))] - if len(un_replicated): - raise errors.BrainPyError(f'Some variables were not replicated: {un_replicated}.' - f'did you forget to call xx.replicate() on them?') - _args = [] - for i, x in enumerate(args): - if i + 2 in static_broadcasted_argnums: - _args.append(x) - else: - _args.append(jax.tree_map(_device_reshape, [x])[0]) - dyn_data = dyn_vars.dict() - rand_data = rand_vars.dict() - output, dyn_changes, rand_changes = pmapped_func(dyn_data, rand_data, *_args) - dyn_vars.assign(dyn_changes) - rand_vars.assign(rand_changes) - return jax.tree_map(reduce_func, output) - - return change_func_name(name=f_name, f=call) if f_name else call - - -def pmap(func, dyn_vars=None, axis_name=None, in_axes=0, out_axes=0, static_broadcasted_argnums=(), - devices=None, backend=None, axis_size=None, donate_argnums=(), global_arg_shapes=None, - reduce_func=None): - """Parallel compilation for class objects. - - Parallel compile a function or a module to run on multiple devices in parallel. - - Parameters - ---------- - func - axis_name - in_axes - out_axes - static_broadcasted_argnums - devices - backend - axis_size - donate_argnums - global_arg_shapes - - Returns - ------- - - - Examples - -------- - - - """ - - # if isinstance(func, DynamicalSystem): - # if len(func.steps): # DynamicalSystem has step functions - # - # # dynamical variables - # all_vars = (dyn_vars or func.vars().unique()) - # dyn_vars = ArrayCollector() - # rand_vars = ArrayCollector() - # for key, val in all_vars.items(): - # if isinstance(val, RandomState): - # rand_vars[key] = val - # else: - # dyn_vars[key] = val - # - # # reduce function - # if reduce_func is None: - # reduce_func = jnp.concatenate - # - # # static broadcast-ed arguments - # if static_broadcasted_argnums is None: - # static_broadcasted_argnums = () - # elif isinstance(static_broadcasted_argnums, int): - # static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - # elif isinstance(static_broadcasted_argnums, (tuple, list)): - # static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - # assert isinstance(static_broadcasted_argnums, (tuple, list)) - # - # # jit functions - # for key in func.steps.keys(): - # step = func.steps[key] - # func.steps[key] = _make_pmap(dyn_vars=dyn_vars, - # rand_vars=rand_vars, - # func=step, - # axis_name=axis_name, - # in_axes=in_axes, - # out_axes=out_axes, - # static_broadcasted_argnums=static_broadcasted_argnums, - # devices=devices, - # backend=backend, - # axis_size=axis_size, - # donate_argnums=donate_argnums, - # global_arg_shapes=global_arg_shapes, - # reduce_func=reduce_func, - # f_name=key) - # return func - - if callable(func): - if dyn_vars is not None: - dyn_vars = dyn_vars - elif isinstance(func, BrainPyObject): # BrainPyObject has '__call__()' implementation - dyn_vars = func.vars().unique() - elif hasattr(func, '__self__'): - if isinstance(func.__self__, BrainPyObject): - dyn_vars = func.__self__.vars().unique() - - if dyn_vars is None: - return jax.pmap(func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes) - else: - # dynamical variables - dyn_vars = ArrayCollector() - rand_vars = ArrayCollector() - for key, val in dyn_vars.items(): - if isinstance(val, RandomState): - rand_vars[key] = val - else: - dyn_vars[key] = val - - # static broadcast-ed arguments - if static_broadcasted_argnums is None: - static_broadcasted_argnums = () - elif isinstance(static_broadcasted_argnums, int): - static_broadcasted_argnums = (static_broadcasted_argnums + 2,) - elif isinstance(static_broadcasted_argnums, (tuple, list)): - static_broadcasted_argnums = tuple(argnum + 2 for argnum in static_broadcasted_argnums) - assert isinstance(static_broadcasted_argnums, (tuple, list)) - - # reduce function - if reduce_func is None: - reduce_func = jnp.concatenate - - # jit function - func.__call__ = _make_pmap(dyn_vars=dyn_vars, - rand_vars=rand_vars, - func=func, - axis_name=axis_name, - in_axes=in_axes, - out_axes=out_axes, - static_broadcasted_argnums=static_broadcasted_argnums, - devices=devices, - backend=backend, - axis_size=axis_size, - donate_argnums=donate_argnums, - global_arg_shapes=global_arg_shapes, - reduce_func=reduce_func) - return func - - else: - raise errors.BrainPyError(f'Only support instance of {BrainPyObject.__name__}, or a callable function, ' - f'but we got {type(func)}.') diff --git a/brainpy/_src/math/object_transform/tools.py b/brainpy/_src/math/object_transform/tools.py index 7b519590..632c6d79 100644 --- a/brainpy/_src/math/object_transform/tools.py +++ b/brainpy/_src/math/object_transform/tools.py @@ -132,19 +132,65 @@ def evaluate_dyn_vars_with_cache( return stack +def _partial_fun2( + fun: Callable, + args: tuple, + kwargs: dict, + static_argnums: Sequence[int] = (), + static_argnames: Sequence[str] = () +): + num_args = len(args) + + # arguments + static_args = dict() + dyn_args = [] + dyn_arg_ids = dict() + static_argnums = list(static_argnums) + dyn_i = 0 + for i in range(num_args): + if i in static_argnums: + static_argnums.remove(i) + static_args[i] = args[i] + else: + dyn_args.append(args[i]) + dyn_arg_ids[i] = dyn_i + dyn_i += 1 + if len(static_argnums) > 0: + raise ValueError(f"Invalid static_argnums: {static_argnums}") + + # keyword arguments + static_kwargs, dyn_kwargs = {}, {} + for k, arg in kwargs.items(): + if k in static_argnames: + static_kwargs[k] = arg + else: + dyn_kwargs[k] = arg + del args, kwargs, static_argnums, static_argnames + + @wraps(fun) + def new_fun(*dynargs, **dynkwargs): + return fun(*[dynargs[dyn_arg_ids[id_]] if id_ in dyn_arg_ids else static_args[id_] for id_ in range(num_args)], + **static_kwargs, + **dynkwargs) + + return new_fun, dyn_args, dyn_kwargs + + def eval_shape( fun: Callable, *args, static_argnums: Sequence[int] = (), static_argnames: Sequence[str] = (), + with_stack: bool = False, **kwargs ): """Compute the shape/dtype of ``fun`` without any FLOPs. Args: fun: The callable function. - *args: - **kwargs: + *args: The positional arguments. + **kwargs: The keyword arguments. + with_stack: Whether evaluate the function within a local variable stack. static_argnums: The static argument indices. static_argnames: The static argument names. @@ -153,21 +199,30 @@ def eval_shape( """ # reorganize the function if len(static_argnums) or len(static_argnames): - f2, args, kwargs = _partial_fun(fun, args, kwargs, - static_argnums=static_argnums, - static_argnames=static_argnames) + f2, args, kwargs = _partial_fun2(fun, args, kwargs, static_argnums=static_argnums, static_argnames=static_argnames) else: - f2, args, kwargs = fun, args, kwargs + f2 = fun # evaluate the function fun_in_eval_shape.append(fun) try: - with jax.ensure_compile_time_eval(): + if with_stack: with VariableStack() as stack: if len(fun_in_eval_shape) > 1: - returns = fun(*args, **kwargs) + returns = f2(*args, **kwargs) else: - returns = jax.eval_shape(fun, *args, **kwargs) + returns = jax.eval_shape(f2, *args, **kwargs) + else: + stack = None + if len(fun_in_eval_shape) > 1: + returns = f2(*args, **kwargs) + else: + returns = jax.eval_shape(f2, *args, **kwargs) finally: fun_in_eval_shape.pop() - return stack, returns + del f2 + if with_stack: + return stack, returns + else: + return returns + diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py index 5014da0b..b7babae8 100644 --- a/brainpy/_src/math/object_transform/variables.py +++ b/brainpy/_src/math/object_transform/variables.py @@ -1,4 +1,3 @@ -from contextlib import contextmanager from typing import Optional, Any, List, Callable, Sequence, Union, Dict, Tuple import jax @@ -190,6 +189,14 @@ def remove_by_id(self, *ids, error_when_absent=False): remove_var_by_id = remove_by_id + @classmethod + def num_of_stack(self): + return len(var_stack_list) + + @classmethod + def is_first_stack(self): + return len(var_stack_list) == 0 + def __enter__(self) -> 'VariableStack': self.collect_values() # recollect the original value of each variable var_stack_list.append(self) @@ -210,42 +217,6 @@ def __add__(self, other: dict): var_stack_list: List[VariableStack] = [] -transform_stack: List[Callable] = [] - - -@contextmanager -def new_transform(transform: Any): - transform_stack.append(transform) - try: - yield - finally: - transform_stack.pop() - - -def outermost_stack(): - if len(var_stack_list): - return var_stack_list[0] - else: - return None - - -def outermost_transform(): - if len(transform_stack): - return transform_stack[0] - else: - return None - - -def current_transform_number(): - return len(transform_stack) - - -def _stack_add_read(var: 'Variable'): - pass - - -def _stack_add_write(var: 'Variable'): - pass @register_pytree_node_class diff --git a/brainpy/_src/tools/functions.py b/brainpy/_src/tools/functions.py new file mode 100644 index 00000000..cbc710db --- /dev/null +++ b/brainpy/_src/tools/functions.py @@ -0,0 +1,192 @@ +import inspect +from functools import partial +from operator import attrgetter +from types import MethodType + +__all__ = [ + 'compose', 'pipe' +] + + +def identity(x): + """ Identity function. Return x + + >>> identity(3) + 3 + """ + return x + + +def instanceproperty(fget=None, fset=None, fdel=None, doc=None, classval=None): + """ Like @property, but returns ``classval`` when used as a class attribute + + >>> class MyClass(object): + ... '''The class docstring''' + ... @instanceproperty(classval=__doc__) + ... def __doc__(self): + ... return 'An object docstring' + ... @instanceproperty + ... def val(self): + ... return 42 + ... + >>> MyClass.__doc__ + 'The class docstring' + >>> MyClass.val is None + True + >>> obj = MyClass() + >>> obj.__doc__ + 'An object docstring' + >>> obj.val + 42 + """ + if fget is None: + return partial(instanceproperty, fset=fset, fdel=fdel, doc=doc, + classval=classval) + return InstanceProperty(fget=fget, fset=fset, fdel=fdel, doc=doc, + classval=classval) + + +class InstanceProperty(property): + """ Like @property, but returns ``classval`` when used as a class attribute + + Should not be used directly. Use ``instanceproperty`` instead. + """ + + def __init__(self, fget=None, fset=None, fdel=None, doc=None, + classval=None): + self.classval = classval + property.__init__(self, fget=fget, fset=fset, fdel=fdel, doc=doc) + + def __get__(self, obj, type=None): + if obj is None: + return self.classval + return property.__get__(self, obj, type) + + def __reduce__(self): + state = (self.fget, self.fset, self.fdel, self.__doc__, self.classval) + return InstanceProperty, state + + +class Compose(object): + """ A composition of functions + + See Also: + compose + """ + __slots__ = 'first', 'funcs' + + def __init__(self, funcs): + funcs = tuple(reversed(funcs)) + self.first = funcs[0] + self.funcs = funcs[1:] + + def __call__(self, *args, **kwargs): + ret = self.first(*args, **kwargs) + for f in self.funcs: + ret = f(ret) + return ret + + def __getstate__(self): + return self.first, self.funcs + + def __setstate__(self, state): + self.first, self.funcs = state + + @instanceproperty(classval=__doc__) + def __doc__(self): + def composed_doc(*fs): + """Generate a docstring for the composition of fs. + """ + if not fs: + # Argument name for the docstring. + return '*args, **kwargs' + + return '{f}({g})'.format(f=fs[0].__name__, g=composed_doc(*fs[1:])) + + try: + return ( + 'lambda *args, **kwargs: ' + + composed_doc(*reversed((self.first,) + self.funcs)) + ) + except AttributeError: + # One of our callables does not have a `__name__`, whatever. + return 'A composition of functions' + + @property + def __name__(self): + try: + return '_of_'.join( + (f.__name__ for f in reversed((self.first,) + self.funcs)) + ) + except AttributeError: + return type(self).__name__ + + def __repr__(self): + return '{.__class__.__name__}{!r}'.format( + self, tuple(reversed((self.first,) + self.funcs))) + + def __eq__(self, other): + if isinstance(other, Compose): + return other.first == self.first and other.funcs == self.funcs + return NotImplemented + + def __ne__(self, other): + equality = self.__eq__(other) + return NotImplemented if equality is NotImplemented else not equality + + def __hash__(self): + return hash(self.first) ^ hash(self.funcs) + + # Mimic the descriptor behavior of python functions. + # i.e. let Compose be called as a method when bound to a class. + # adapted from + # docs.python.org/3/howto/descriptor.html#functions-and-methods + def __get__(self, obj, objtype=None): + return self if obj is None else MethodType(self, obj) + + # introspection with Signature is only possible from py3.3+ + @instanceproperty + def __signature__(self): + base = inspect.signature(self.first) + last = inspect.signature(self.funcs[-1]) + return base.replace(return_annotation=last.return_annotation) + + __wrapped__ = instanceproperty(attrgetter('first')) + + +def compose(*funcs): + """ Compose functions to operate in series. + + Returns a function that applies other functions in sequence. + + Functions are applied from right to left so that + ``compose(f, g, h)(x, y)`` is the same as ``f(g(h(x, y)))``. + + If no arguments are provided, the identity function (f(x) = x) is returned. + + >>> inc = lambda i: i + 1 + >>> compose(str, inc)(3) + '4' + """ + if not funcs: + return identity + if len(funcs) == 1: + return funcs[0] + else: + return Compose(funcs) + + +def pipe(*funcs): + """ Pipe a value through a sequence of functions + + I.e. ``pipe(f, g, h)(data)`` is equivalent to ``h(g(f(data)))`` + + We think of the value as progressing through a pipe of several + transformations, much like pipes in UNIX + + + >>> double = lambda i: 2 * i + >>> pipe(double, str)(3) + '6' + """ + return compose(*reversed(funcs)) diff --git a/brainpy/_src/tools/tests/test_functions.py b/brainpy/_src/tools/tests/test_functions.py new file mode 100644 index 00000000..c285e561 --- /dev/null +++ b/brainpy/_src/tools/tests/test_functions.py @@ -0,0 +1,24 @@ + +import unittest + +import brainpy as bp +import brainpy.math as bm + + +class TestFunction(unittest.TestCase): + def test_compose(self): + f = lambda a: a + 1 + g = lambda a: a * 10 + fun1 = bp.tools.compose(f, g) + fun2 = bp.tools.pipe(g, f) + + arr = bm.random.randn(10) + r1 = fun1(arr) + r2 = fun2(arr) + groundtruth = f(g(arr)) + self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, groundtruth)) + bm.clear_buffer_memory() + + + diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index e4570f6f..3b0c3f51 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -12,7 +12,7 @@ arccos as arccos, acosh as acosh, arccosh as arccosh, - add as add, + # add as add, addcdiv as addcdiv, addcmul as addcmul, angle as angle, diff --git a/brainpy/math/oo_transform.py b/brainpy/math/oo_transform.py index 548a987d..a488e074 100644 --- a/brainpy/math/oo_transform.py +++ b/brainpy/math/oo_transform.py @@ -58,4 +58,6 @@ from brainpy._src.math.object_transform.tools import ( eval_shape as eval_shape, ) - +from brainpy._src.math.object_transform.variables import ( + VariableStack as VariableStack, +) diff --git a/brainpy/tools.py b/brainpy/tools.py index 0f3a4c0e..35e98f6d 100644 --- a/brainpy/tools.py +++ b/brainpy/tools.py @@ -43,6 +43,10 @@ from brainpy._src.tools.install import ( jaxlib_install_info, ) +from brainpy._src.tools.functions import ( + compose as compose, + pipe as pipe, +) diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 5c8cba0f..0b78315a 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -3,13 +3,52 @@ Advanced Tutorials This section contains tutorials that illustrate more advanced features of BrainPy. +Advanced Math +------------- .. toctree:: - :maxdepth: 2 + :maxdepth: 1 + + tutorial_advanced/compilation.ipynb + tutorial_advanced/differentiation.ipynb + + +Interoperation +-------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/integrate_flax_into_brainpy.ipynb + tutorial_advanced/integrate_bp_lif_into_flax.ipynb + tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb + + +Brain Dynamics Dedicated Operators +---------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/operator_custom_with_numba.ipynb + tutorial_advanced/operator_custom_with_taichi.ipynb + + +Developer Guides +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/contributing.md + + +Others +------ + +.. toctree:: + :maxdepth: 1 + + tutorial_advanced/advanced_lowdim_analysis.ipynb - tutorial_advanced/1_advanced_math.rst - tutorial_advanced/2_interoperation.rst - tutorial_advanced/3_dedicated_operators.rst - tutorial_advanced/4_developer_guides.rst - tutorial_advanced/5_others.rst diff --git a/docs/apis/brainpy.math.oo_transform.rst b/docs/apis/brainpy.math.oo_transform.rst index 754e0d81..9ed9cf46 100644 --- a/docs/apis/brainpy.math.oo_transform.rst +++ b/docs/apis/brainpy.math.oo_transform.rst @@ -77,4 +77,5 @@ Helpers for Object-oriented Transformations :template: classtemplate.rst eval_shape + VariableStack diff --git a/docs/toolboxes.rst b/docs/toolboxes.rst index 11bf5311..cc3a3857 100644 --- a/docs/toolboxes.rst +++ b/docs/toolboxes.rst @@ -1,7 +1,16 @@ BDP Toolboxes ================== + + + This section contains detailed toolboxes BrainPy uses for brain dynamics modeling. + + +Differential Equations +----------------------- + + .. toctree:: :maxdepth: 1 @@ -10,11 +19,34 @@ This section contains detailed toolboxes BrainPy uses for brain dynamics modelin tutorial_toolbox/fde_numerical_solvers tutorial_toolbox/dde_numerical_solvers tutorial_toolbox/joint_equations + + +Toolbox for Modeling +------------------- + +.. toctree:: + :maxdepth: 1 + tutorial_toolbox/synaptic_connections tutorial_toolbox/synaptic_weights + tutorial_toolbox/inputs + + +Toolbox for Training +-------------------- + +.. toctree:: + :maxdepth: 1 + tutorial_toolbox/optimizers - tutorial_toolbox/state_saving_and_loading.ipynb - tutorial_toolbox/state_resetting.ipynb tutorial_toolbox/surrogate_gradient - tutorial_toolbox/inputs + +State Resetting, Saving and Loading +----------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_toolbox/state_saving_and_loading.ipynb + tutorial_toolbox/state_resetting.ipynb \ No newline at end of file diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 7c9a1c87..57d18332 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -3,11 +3,76 @@ BDP Tutorials This section contains tutorials on how to use BrainPy to accomplish model building, simulation, training, and analysis. + +Math Foundation +--------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_math/variables + tutorial_math/control_flows + tutorial_math/Numpy_like_Operations.ipynb + tutorial_math/Dedicated_Operators.ipynb + tutorial_math/einops_in_brainpy.ipynb + + +Model Building with Existing Modules +------------------------------------ + +.. toctree:: + :maxdepth: 1 + + tutorial_building/overview_of_dynamic_model + tutorial_building/build_conductance_neurons_v2.ipynb + tutorial_building/phenon_synapse_models.ipynb + tutorial_building/kinetic_synapse_models.ipynb + tutorial_building/build_network_models + + +Model Building by Customizing New Modules +----------------------------------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_building/customize_neuron_models + tutorial_building/customize_synapse_models + tutorial_building/how_to_customze_a_synapse.ipynb + + +Model Simulation +---------------- + +.. toctree:: + :maxdepth: 1 + + tutorial_simulation/simulation_dsrunner.ipynb + tutorial_simulation/parallel_for_parameter_exploration.ipynb + tutorial_simulation/monitor_per_multiple_steps.ipynb + + +Model Training +-------------- + +This tutorial shows how to train a dynamical system from data or task. + +.. toctree:: + :maxdepth: 1 + + tutorial_training/build_training_models.ipynb + tutorial_training/offline_training.ipynb + tutorial_training/online_training.ipynb + tutorial_training/bp_training.ipynb + tutorial_training/esn_introduction.ipynb + + +Model Analysis +-------------- + .. toctree:: - :maxdepth: 2 + :maxdepth: 1 - tutorial_math/index - tutorial_building/index - tutorial_simulation/index - tutorial_training/index - tutorial_analysis/index + tutorial_analysis/lowdim_analysis + tutorial_analysis/highdim_analysis + tutorial_analysis/decision_making_model From 6c367947da4a108265d346fcffefe1a94d8fc4a4 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 5 Mar 2024 20:27:56 +0800 Subject: [PATCH 12/21] Update installation instruction (#651) * update * update installation docs --- brainpylib-changelog.md | 4 +++- docs/index.rst | 14 +++++++++++--- docs/quickstart/installation.rst | 18 ++++++++++++++++-- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/brainpylib-changelog.md b/brainpylib-changelog.md index 888a9c68..bce0ac13 100644 --- a/brainpylib-changelog.md +++ b/brainpylib-changelog.md @@ -2,18 +2,20 @@ ## Version 0.3.0 +- Support `brainpy>=2.5.0` - Fix bugs on windows platform - remove all customized C++ and CUDA operators - ## Version 0.2.8 +- Support `brainpy>=2.5.0` - Fix bugs that the DLL cannot be loaded correctly when windows does not have a c++ environment, ## ~~Version 0.2.7(YANKED)~~ ## Version 0.2.6 +- Support `brainpy>=2.5.0` - Fix bugs of taichi call function for single result ## Version 0.2.5 diff --git a/docs/index.rst b/docs/index.rst index ada4a873..00271b41 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -19,16 +19,24 @@ Installation pip install -U brainpy[cpu] # windows, linux, macos - .. tab-item:: GPU (CUDA) + .. tab-item:: GPU (CUDA 11.0) .. code-block:: bash - # for CUDA 11.0, Linux only pip install -U brainpy[cuda11] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html - # for CUDA 12.0, Linux only + .. tab-item:: GPU (CUDA 12.0) + + .. code-block:: bash + pip install -U brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html + .. tab-item:: TPU + + .. code-block:: bash + + pip install -U brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html + For more information, please see `installation `_ section. diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 6f51bfbd..6931a1e3 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -17,8 +17,7 @@ To install brainpy with minimum requirements (has installed ``jax`` and ``jaxlib .. code-block:: bash - pip install brainpy # for CPU - + pip install brainpy Minimum requirements (with dependencies) @@ -35,6 +34,9 @@ To install brainpy with minimum requirements (only depends on ``jax``), you can pip install brainpy[cuda11_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 11.0 pip install brainpy[cuda12_mini] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html # for CUDA 12.0 + # or + + pip install brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # for google TPU CPU with all dependencies @@ -61,6 +63,18 @@ To install a GPU-only version of BrainPy, you can run +TPU with all dependencies +------------------------- + +BrainPy supports Google Cloud TPU. To install BrainPy along with appropriate versions of jax, +you can run the following in your cloud TPU VM: + +.. code-block:: bash + + pip install brainpy[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html # for google TPU + + + ``brainpylib`` -------------- From 7511afdc8b7db66ec8116448d8c085efa4c9c69b Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 5 Mar 2024 20:28:26 +0800 Subject: [PATCH 13/21] Fix delay bugs (#650) * remove `reset_state()` * make default attribute None * fix delay bugs * fix `DynamicalSystem.register_local_delay()` bug * add the following decorators for enhancing the ``update()`` capability: - `brainpy.receive_update_output()` - `brainpy.receive_update_input()` - `brainpy.not_receive_update_output()` - `brainpy.not_receive_update_input()` * fix --- brainpy/__init__.py | 5 +- brainpy/_src/delay.py | 42 ++++-- .../_src/dynold/synapses/abstract_models.py | 25 ++-- brainpy/_src/dynold/synapses/base.py | 14 +- brainpy/_src/dynsys.py | 129 ++++++++++++++++-- brainpy/_src/math/object_transform/base.py | 21 ++- brainpy/_src/tests/test_base_classes.py | 50 +++++++ brainpy/_src/tests/test_delay.py | 50 ++++++- brainpy/_src/tests/test_mixin.py | 2 +- 9 files changed, 275 insertions(+), 63 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index a3a1de69..79aa216b 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -61,6 +61,10 @@ Sequential as Sequential, Dynamic as Dynamic, # category Projection as Projection, + receive_update_input, # decorators + receive_update_output, + not_receive_update_input, + not_receive_update_output, ) DynamicalSystemNS = DynamicalSystem Network = DynSysGroup @@ -84,7 +88,6 @@ load_state as load_state, clear_input as clear_input) - # Part: Running # # --------------- # from brainpy._src.runners import (DSRunner as DSRunner) diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index ee0be576..66530a5b 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -28,7 +28,21 @@ ] -delay_identifier = '_*_delay_*_' +delay_identifier = '_*_delay_of_' + + +def _get_delay(delay_time, delay_step): + if delay_time is None: + if delay_step is None: + return None, None + else: + assert isinstance(delay_step, int), '"delay_step" should be an integer.' + delay_time = delay_step * bm.get_dt() + else: + assert delay_step is None, '"delay_step" should be None if "delay_time" is given.' + assert isinstance(delay_time, (int, float)) + delay_step = math.ceil(delay_time / bm.get_dt()) + return delay_time, delay_step class Delay(DynamicalSystem, ParamDesc): @@ -97,13 +111,15 @@ def __init__( def register_entry( self, entry: str, - delay_time: Optional[Union[float, bm.Array, Callable]], + delay_time: Optional[Union[float, bm.Array, Callable]] = None, + delay_step: Optional[int] = None ) -> 'Delay': """Register an entry to access the data. Args: entry: str. The entry to access the delay data. delay_time: The delay time of the entry (can be a float). + delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``. Returns: Return the self. @@ -237,13 +253,15 @@ def __init__( def register_entry( self, entry: str, - delay_time: Optional[Union[int, float]], + delay_time: Optional[Union[int, float]] = None, + delay_step: Optional[int] = None, ) -> 'Delay': """Register an entry to access the data. Args: entry: str. The entry to access the delay data. delay_time: The delay time of the entry (can be a float). + delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``. Returns: Return the self. @@ -258,12 +276,7 @@ def register_entry( assert delay_time.size == 1 and delay_time.ndim == 0 delay_time = delay_time.item() - if delay_time is None: - delay_step = None - delay_time = 0. - else: - assert isinstance(delay_time, (int, float)) - delay_step = math.ceil(delay_time / bm.get_dt()) + _, delay_step = _get_delay(delay_time, delay_step) # delay variable if delay_step is not None: @@ -354,6 +367,8 @@ def update( """Update delay variable with the new data. """ if self.data is not None: + # jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value)) + # get the latest target value if latest_value is None: latest_value = self.target.value @@ -361,17 +376,20 @@ def update( # update the delay data at the rotation index if self.method == ROTATE_UPDATE: i = share.load('i') - idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32) - self.data[idx] = latest_value + idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32) + self.data[jax.lax.stop_gradient(idx)] = latest_value # update the delay data at the first position elif self.method == CONCAT_UPDATE: if self.max_length > 1: latest_value = bm.expand_dims(latest_value, 0) - self.data.value = bm.concat([latest_value, self.data[1:]], axis=0) + self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0) else: self.data[0] = latest_value + else: + raise ValueError(f'Unknown updating method "{self.method}"') + def reset_state(self, batch_size: int = None, **kwargs): """Reset the delay data. """ diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py index c7a902f0..2e214ed2 100644 --- a/brainpy/_src/dynold/synapses/abstract_models.py +++ b/brainpy/_src/dynold/synapses/abstract_models.py @@ -115,12 +115,7 @@ def __init__( self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr') # register delay - self.pre.register_local_delay("spike", self.name, delay_step) - - def reset_state(self, batch_size=None): - self.output.reset_state(batch_size) - if self.stp is not None: - self.stp.reset_state(batch_size) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) def update(self, pre_spike=None): # pre-synaptic spikes @@ -232,7 +227,6 @@ class Exponential(TwoEndConn): method: str The numerical integration methods. - """ def __init__( @@ -283,17 +277,16 @@ def __init__( else: raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".') - # variables - self.g = self.syn.g - # delay - self.pre.register_local_delay("spike", self.name, delay_step) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) - def reset_state(self, batch_size=None): - self.syn.reset_state(batch_size) - self.output.reset_state(batch_size) - if self.stp is not None: - self.stp.reset_state(batch_size) + @property + def g(self): + return self.syn.g + + @g.setter + def g(self, value): + self.syn.g = value def update(self, pre_spike=None): # delays diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index 55bac711..5ceeb4e2 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -10,8 +10,7 @@ from brainpy._src.dyn.base import NeuDyn from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import parameter -from brainpy._src.mixin import (ParamDesc, JointType, - SupportAutoDelay, BindCondData, ReturnInfo) +from brainpy._src.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo) from brainpy.errors import UnsupportedError from brainpy.types import ArrayType @@ -47,9 +46,6 @@ def isregistered(self, val: bool): raise ValueError('Must be an instance of bool.') self._registered = val - def reset_state(self, batch_size=None): - pass - def register_master(self, master: SynConn): if not isinstance(master, SynConn): raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}') @@ -296,7 +292,7 @@ def __init__( mode=mode) # delay - self.pre.register_local_delay("spike", self.name, delay_step) + self.pre.register_local_delay("spike", self.name, delay_step=delay_step) # synaptic dynamics self.syn = syn @@ -340,11 +336,5 @@ def g_max(self, v): UserWarning) self.comm.weight = v - def reset_state(self, *args, **kwargs): - self.syn.reset(*args, **kwargs) - self.comm.reset(*args, **kwargs) - self.output.reset(*args, **kwargs) - if self.stp is not None: - self.stp.reset(*args, **kwargs) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index cb086b10..a6fcc16a 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -93,17 +93,41 @@ def __init__( # Attribute for "SupportInputProj" # each instance of "SupportInputProj" should have a "cur_inputs" attribute - self.current_inputs = bm.node_dict() - self.delta_inputs = bm.node_dict() + self._current_inputs: Optional[Dict[str, Callable]] = None + self._delta_inputs: Optional[Dict[str, Callable]] = None # the before- / after-updates used for computing # added after the version of 2.4.3 - self.before_updates: Dict[str, Callable] = bm.node_dict() - self.after_updates: Dict[str, Callable] = bm.node_dict() + self._before_updates: Optional[Dict[str, Callable]] = None + self._after_updates: Optional[Dict[str, Callable]] = None # super initialization super().__init__(name=name) + @property + def current_inputs(self): + if self._current_inputs is None: + self._current_inputs = bm.node_dict() + return self._current_inputs + + @property + def delta_inputs(self): + if self._delta_inputs is None: + self._delta_inputs = bm.node_dict() + return self._delta_inputs + + @property + def before_updates(self): + if self._before_updates is None: + self._before_updates = bm.node_dict() + return self._before_updates + + @property + def after_updates(self): + if self._after_updates is None: + self._after_updates = bm.node_dict() + return self._after_updates + def add_bef_update(self, key: Any, fun: Callable): """Add the before update into this node""" if key in self.before_updates: @@ -220,25 +244,32 @@ def register_local_delay( self, var_name: str, delay_name: str, - delay: Union[numbers.Number, ArrayType] = None, + delay_time: Union[numbers.Number, ArrayType] = None, + delay_step: Union[numbers.Number, ArrayType] = None, ): """Register local relay at the given delay time. Args: var_name: str. The name of the delay target variable. delay_name: str. The name of the current delay data. - delay: The delay time. + delay_time: The delay time. Float. + delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``. """ delay_identifier, init_delay_by_return = _get_delay_tool() delay_identifier = delay_identifier + var_name + # check whether the "var_name" has been registered try: target = getattr(self, var_name) except AttributeError: raise AttributeError(f'This node {self} does not has attribute of "{var_name}".') if not self.has_aft_update(delay_identifier): - self.add_aft_update(delay_identifier, init_delay_by_return(target)) + # add a model to receive the return of the target model + # moreover, the model should not receive the return of the update function + model = not_receive_update_output(init_delay_by_return(target)) + # register the model + self.add_aft_update(delay_identifier, model) delay_cls = self.get_aft_update(delay_identifier) - delay_cls.register_entry(delay_name, delay) + delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step) def get_local_delay(self, var_name, delay_name): """Get the delay at the given identifier (`name`). @@ -381,14 +412,20 @@ def __call__(self, *args, **kwargs): # ``before_updates`` for model in self.before_updates.values(): - model() + if hasattr(model, '_receive_update_input'): + model(*args, **kwargs) + else: + model() # update the model self ret = self.update(*args, **kwargs) # ``after_updates`` for model in self.after_updates.values(): - model(ret) + if hasattr(model, '_not_receive_update_output'): + model() + else: + model(ret) return ret def __rrshift__(self, other): @@ -832,3 +869,75 @@ def _slice_to_num(slice_: slice, length: int): start += step num += 1 return num + + +def receive_update_output(cls: object): + """ + The decorator to mark the object (as the after updates) to receive the output of the update function. + + That is, the `aft_update` will receive the return of the update function:: + + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun(ret) + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + if hasattr(cls, '_not_receive_update_output'): + delattr(cls, '_not_receive_update_output') + return cls + + +def not_receive_update_output(cls: object): + """ + The decorator to mark the object (as the after updates) to not receive the output of the update function. + + That is, the `aft_update` will not receive the return of the update function:: + + ret = model.update(*args, **kwargs) + for fun in model.aft_updates: + fun() + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._not_receive_update_output = True + return cls + + +def receive_update_input(cls: object): + """ + The decorator to mark the object (as the before updates) to receive the input of the update function. + + That is, the `bef_update` will receive the input of the update function:: + + + for fun in model.bef_updates: + fun(*args, **kwargs) + model.update(*args, **kwargs) + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + cls._receive_update_input = True + return cls + + +def not_receive_update_input(cls: object): + """ + The decorator to mark the object (as the before updates) to not receive the input of the update function. + + That is, the `bef_update` will not receive the input of the update function:: + + for fun in model.bef_updates: + fun() + model.update() + + """ + # assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.' + if hasattr(cls, '_receive_update_input'): + delattr(cls, '_receive_update_input') + return cls + + + + + diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py index 936f6238..de64f94e 100644 --- a/brainpy/_src/math/object_transform/base.py +++ b/brainpy/_src/math/object_transform/base.py @@ -14,6 +14,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class +from brainpy._src.math import defaults from brainpy._src.math.modes import Mode from brainpy._src.math.ndarray import (Array, ) from brainpy._src.math.object_transform.collectors import (ArrayCollector, Collector) @@ -22,13 +23,11 @@ from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar, VarList, VarDict) from brainpy._src.math.sharding import BATCH_AXIS -from brainpy._src.math import defaults variable_ = None StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys']) registered = set() - __all__ = [ 'BrainPyObject', 'Base', 'FunAsObject', 'ObjectTransform', @@ -103,11 +102,23 @@ def __init__(self, name=None): # Used to wrap the implicit variables # which cannot be accessed by self.xxx - self.implicit_vars: ArrayCollector = ArrayCollector() + self._implicit_vars: Optional[ArrayCollector] = None # Used to wrap the implicit children nodes # which cannot be accessed by self.xxx - self.implicit_nodes: Collector = Collector() + self._implicit_nodes: Optional[Collector] = None + + @property + def implicit_vars(self): + if self._implicit_vars is None: + self._implicit_vars = ArrayCollector() + return self._implicit_vars + + @property + def implicit_nodes(self): + if self._implicit_nodes is None: + self._implicit_nodes = Collector() + return self._implicit_nodes def setattr(self, key: str, value: Any) -> None: super().__setattr__(key, value) @@ -225,7 +236,7 @@ def tree_flatten(self): static_values = [] for k, v in self.__dict__.items(): if isinstance(v, (BrainPyObject, Variable, NodeList, NodeDict, VarList, VarDict)): - # if isinstance(v, (BrainPyObject, Variable)): + # if isinstance(v, (BrainPyObject, Variable)): dynamic_names.append(k) dynamic_values.append(v) else: diff --git a/brainpy/_src/tests/test_base_classes.py b/brainpy/_src/tests/test_base_classes.py index 9c095a30..3534f0a4 100644 --- a/brainpy/_src/tests/test_base_classes.py +++ b/brainpy/_src/tests/test_base_classes.py @@ -3,6 +3,7 @@ import unittest import brainpy as bp +import brainpy.math as bm class TestDynamicalSystem(unittest.TestCase): @@ -17,4 +18,53 @@ def test_delay(self): runner = bp.DSRunner(net,) runner.run(10.) + bm.clear_buffer_memory() + + def test_receive_update_output(self): + def aft_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', aft_update) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_do_not_receive_update_output(self): + def aft_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_aft_update('aft_update', bp.not_receive_update_output(aft_update)) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_not_receive_update_input(self): + def bef_update(): + pass + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bef_update) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + def test_receive_update_input(self): + def bef_update(inp): + assert inp is not None + + hh = bp.dyn.HH(1) + hh.add_bef_update('bef_update', bp.receive_update_input(bef_update)) + bp.share.save(i=0, t=0.) + hh(1.) + + bm.clear_buffer_memory() + + + + diff --git a/brainpy/_src/tests/test_delay.py b/brainpy/_src/tests/test_delay.py index 20d49937..b7bd44ea 100644 --- a/brainpy/_src/tests/test_delay.py +++ b/brainpy/_src/tests/test_delay.py @@ -1,13 +1,15 @@ +import unittest + +import jax.numpy as jnp import brainpy as bp -import unittest class TestVarDelay(unittest.TestCase): def test_delay1(self): bp.math.random.seed() a = bp.math.Variable((10, 20)) - delay = bp.VarDelay(a,) + delay = bp.VarDelay(a, ) delay.register_entry('a', 1.) delay.register_entry('b', 2.) delay.register_entry('c', None) @@ -15,8 +17,44 @@ def test_delay1(self): delay.register_entry('c', 10.) bp.math.clear_buffer_memory() + def test_rotation_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a) + t0 = 0. + t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) + bp.math.clear_buffer_memory() - - - - + def test_concat_delay(self): + a = bp.math.Variable((1,)) + rotation_delay = bp.VarDelay(a, method='concat') + t0 = 0. + t1, n1 = 1., 10 + t2, n2 = 2., 20 + + rotation_delay.register_entry('a', t0) + rotation_delay.register_entry('b', t1) + rotation_delay.register_entry('c', t2) + + print() + for i in range(100): + bp.share.save(i=i) + a.value = jnp.ones((1,)) * i + rotation_delay() + self.assertTrue(jnp.allclose(rotation_delay.at('a'), jnp.ones((1,)) * i)) + self.assertTrue(jnp.allclose(rotation_delay.at('b'), jnp.maximum(jnp.ones((1,)) * i - n1 + 1, 0.))) + self.assertTrue(jnp.allclose(rotation_delay.at('c'), jnp.maximum(jnp.ones((1,)) * i - n2 + 1, 0.))) + bp.math.clear_buffer_memory() diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index 962b76cb..e864fd64 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -42,7 +42,7 @@ class TestDelayRegister(unittest.TestCase): def test2(self): bp.share.save(i=0) lif = bp.dyn.Lif(10) - lif.register_local_delay('spike', 'a', 10.) + lif.register_local_delay('spike', 'a', delay_time=10.) data = lif.get_local_delay('spike', 'a') self.assertTrue(bm.allclose(data, bm.zeros(10))) From 23b5ab957b642b92ba9914a1941b19d780174419 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 8 Mar 2024 15:25:30 +0800 Subject: [PATCH 14/21] update doc (#652) --- docs/index.rst | 1 + docs/quickstart/installation.rst | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 00271b41..d4d4f272 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -17,6 +17,7 @@ Installation .. code-block:: bash + # python 3.9-3.11 pip install -U brainpy[cpu] # windows, linux, macos .. tab-item:: GPU (CUDA 11.0) diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst index 6931a1e3..395bf627 100644 --- a/docs/quickstart/installation.rst +++ b/docs/quickstart/installation.rst @@ -10,8 +10,8 @@ Installation Linux, and MacOS. It only relies on Python libraries. -Minimum requirements (without dependencies) -------------------------------------------- +Without dependencies +-------------------- To install brainpy with minimum requirements (has installed ``jax`` and ``jaxlib`` before), you can use: @@ -23,6 +23,11 @@ To install brainpy with minimum requirements (has installed ``jax`` and ``jaxlib Minimum requirements (with dependencies) ---------------------------------------- +.. note:: + + Full features of brainpy currently is only available on Python 3.9 - 3.11. + + To install brainpy with minimum requirements (only depends on ``jax``), you can use: .. code-block:: bash From 38662033e7523404f1f0c3b449229d59f9d3a7dc Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Wed, 20 Mar 2024 17:10:01 +0800 Subject: [PATCH 15/21] [math] Add new customize operators with `cupy` (#653) * Update * Implement cupy based customized operators and Need to be tested * Fix bugs * Update base.py * Update dependency_check.py * Format codes * Implement customized op with cupy `JIT Kernel` * Update docs * small update --------- Co-authored-by: Chaoming Wang --- brainpy/_src/dependency_check.py | 50 ++++ brainpy/_src/math/op_register/base.py | 53 ++-- brainpy/_src/math/op_register/cupy_based.py | 279 ++++++++++++++++++ .../math/op_register/tests/test_cupy_based.py | 79 +++++ .../op_register/tests/test_taichi_based.py | 19 +- .../operator_custom_with_cupy.ipynb | 174 +++++++++++ .../operator_custom_with_taichi.ipynb | 19 +- 7 files changed, 618 insertions(+), 55 deletions(-) create mode 100644 brainpy/_src/math/op_register/cupy_based.py create mode 100644 brainpy/_src/math/op_register/tests/test_cupy_based.py create mode 100644 docs/tutorial_advanced/operator_custom_with_cupy.ipynb diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index b8bd6e99..1e106062 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -8,6 +8,9 @@ 'raise_taichi_not_found', 'import_numba', 'raise_numba_not_found', + 'import_cupy', + 'import_cupy_jit', + 'raise_cupy_not_found', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] @@ -17,6 +20,8 @@ numba = None taichi = None +cupy = None +cupy_jit = None brainpylib_cpu_ops = None brainpylib_gpu_ops = None @@ -25,6 +30,9 @@ '> pip install taichi==1.7.0') numba_install_info = ('We need numba. Please install numba by pip . \n' '> pip install numba') +cupy_install_info = ('We need cupy. Please install cupy by pip . \n' + 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' + 'For CUDA v12.x > pip install cupy-cuda12x\n') os.environ["TI_LOG_LEVEL"] = "error" @@ -81,6 +89,48 @@ def raise_numba_not_found(): raise ModuleNotFoundError(numba_install_info) +def import_cupy(error_if_not_found=True): + """ + Internal API to import cupy. + + If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global cupy + if cupy is None: + try: + import cupy as cupy + except ModuleNotFoundError: + if error_if_not_found: + raise_cupy_not_found() + else: + return None + return cupy + + +def import_cupy_jit(error_if_not_found=True): + """ + Internal API to import cupy. + + If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global cupy_jit + if cupy_jit is None: + try: + from cupyx import jit as cupy_jit + except ModuleNotFoundError: + if error_if_not_found: + raise_cupy_not_found() + else: + return None + return cupy_jit + + +def raise_cupy_not_found(): + raise ModuleNotFoundError(cupy_install_info) + + def is_brainpylib_gpu_installed(): return False if brainpylib_gpu_ops is None else True diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index ca070a19..5af5a7e3 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -1,11 +1,11 @@ from functools import partial -from typing import Callable, Sequence, Tuple, Protocol, Optional +from typing import Callable, Sequence, Tuple, Protocol, Optional, Union import jax import numpy as np from jax.interpreters import xla, batching, ad, mlir -from brainpy._src.dependency_check import import_numba +from brainpy._src.dependency_check import import_numba, import_cupy_jit from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -13,14 +13,19 @@ from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) + from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) else: from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) + from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp numba = import_numba(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) __all__ = [ 'XLACustomOp', @@ -41,34 +46,10 @@ def dtype(self) -> np.dtype: class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. - >>> import numba as nb - >>> import taichi as ti - >>> import numpy as np - >>> import jax - >>> - >>> @nb.njit - >>> def numba_cpu_fun(a, b, out_a, out_b): - >>> out_a[:] = a - >>> out_b[:] = b - >>> - >>> @ti.kernel - >>> def taichi_gpu_fun(a, b, out_a, out_b): - >>> for i in range(a.size): - >>> out_a[i] = a[i] - >>> for i in range(b.size): - >>> out_b[i] = b[i] - >>> - >>> # option 1 - >>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun) - >>> a2, b2 = prim(np.random.random(1000), np.random.random(1000), - >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), - >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) - >>> - >>> # option 2 - >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun, - >>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype), - >>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)]) - >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000)) + For more information, please refer to the tutorials above: + Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html + Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html + CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html Args: cpu_kernel: Callable. The function defines the computation on CPU backend. @@ -83,7 +64,7 @@ class XLACustomOp(BrainPyObject): def __init__( self, cpu_kernel: Callable = None, - gpu_kernel: Callable = None, + gpu_kernel: Union[Callable, str] = None, batching_translation: Callable = None, jvp_translation: Callable = None, transpose_translation: Callable = None, @@ -125,11 +106,17 @@ def __init__( gpu_checked = False if gpu_kernel is None: gpu_checked = True - if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule + register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel) + gpu_checked = True + elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel + register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel) + gpu_checked = True + elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) gpu_checked = True if not gpu_checked: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}') + raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py new file mode 100644 index 00000000..ad6befec --- /dev/null +++ b/brainpy/_src/math/op_register/cupy_based.py @@ -0,0 +1,279 @@ +from functools import partial, reduce +from typing import List, Tuple + +import jax +import numpy as np +from jax.interpreters import xla, mlir +from jax.lib import xla_client +from jaxlib.hlo_helpers import custom_call + +from brainpy._src.dependency_check import (import_cupy, + import_cupy_jit, + import_brainpylib_gpu_ops) +from brainpy._src.math.op_register.utils import _shape_to_layout +from brainpy.errors import PackageMissingError + +cp = import_cupy(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) + +# convert type to number +type_number_map = { + int: 0, + float: 1, + bool: 2, + np.dtype('int32'): 0, + np.dtype('float32'): 1, + np.dtype('bool'): 2, + np.dtype('uint8'): 3, + np.dtype('uint16'): 4, + np.dtype('uint32'): 5, + np.dtype('uint64'): 6, + np.dtype('int8'): 7, + np.dtype('int16'): 8, + np.dtype('int64'): 9, + np.dtype('float16'): 10, + np.dtype('float64'): 11, +} + + +def _preprocess_kernel_call_gpu( + grid: Tuple[int], + block: Tuple[int], + func_ptr: int, + shared_mem: int, + *ins, + outs: List[jax.ShapeDtypeStruct], +): + grid = (grid + (1, 1))[:3] + block = (block + (1, 1))[:3] + in_num = len(ins) + out_num = len(outs) + in_out_num = [in_num, out_num] + + out_type_list = [0] * out_num + out_elem_count_list = [0] * out_num + + for i, value in enumerate(outs): + out_type_list[i] = type_number_map[value.dtype] + out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape) + + grid = ",".join(str(i) for i in grid) + block = ",".join(str(i) for i in block) + in_out_num_str = ",".join(str(i) for i in in_out_num) + out_type_list_str = ",".join(str(i) for i in out_type_list) + out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list) + + opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' + + bytes(str(shared_mem), encoding='utf-8') + b';' + + bytes(in_out_num_str, encoding='utf-8') + b';' + + bytes(grid, encoding='utf-8') + b';' + + bytes(block, encoding='utf-8') + b';' + + bytes(out_type_list_str, encoding='utf-8') + b';' + + bytes(out_elem_count_list_str, encoding='utf-8') + b';') + return opaque + + +def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + # THE KEY: + # - using the kernel pointer at "kernel.kernel.ptr" + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + + # create custom call + return xla_client.ops.CustomCallWithLayout( + c, + b'cupy_kernel_call_gpu', + operands=ins, + operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), + shape_with_layout=xla_client.Shape.tuple_shape( + [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) + for value in kwargs['outs']] + ), + opaque=opaque, + ) + + +def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel): + xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel) + + +def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + + input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] + result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] + output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] + + return custom_call( + call_target_name='cupy_kernel_call_gpu', + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + backend_config=opaque, + has_side_effect=False, + ).results + + +def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel): + if cp is None: + raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') + + rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel) + mlir.register_lowering(primitive, rule, platform='gpu') + + +def _to_cupy_array_or_scalar(dtype, ndim): + # THE KEY + # - using the cupy jit compiler to get the type + if ndim != 0: + t = cp_jit._cuda_types.CArray(dtype=dtype, + ndim=ndim, + is_c_contiguous=True, + index_32_bits=True) + else: + t = cp_jit._cuda_types.Scalar(dtype=dtype) + return t + + +def _compile_kernel_xla(kernel, in_types): + # THE KEY + # - get the kernel function from the cache + device_id = cp.cuda.get_device_id() + kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None)) + + if kern is None: + # THE KEY: + # - compile the kernel function + result = kernel._cached_codes.get(in_types) + if result is None: + result = cp_jit._compile.transpile( + kernel._func, + ['extern "C"', '__global__'], + 'cuda', + in_types, + cp_jit._cuda_types.void, + ) + kernel._cached_codes[in_types] = result + fname = result.func_name + enable_cg = result.enable_cooperative_groups + options = result.options + backend = result.backend + if backend == 'nvcc': + options += ('-DCUPY_JIT_NVCC',) + jitify = result.jitify + module = cp._core.core.compile_with_cache( + source=result.code, + options=options, + backend=backend, + jitify=jitify, + ) + kern = module.get_function(fname) + kernel._cache[(in_types, device_id)] = (kern, enable_cg) + + return kern + + +def get_jit_kernel_xla(kernel, c, *ins, outs): + # get the input types + in_types = [] + for x in ins: + x = c.get_shape(x) + in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions()))) + for x in outs: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + in_types = tuple(in_types) + # compile the kernel + return _compile_kernel_xla(kernel, in_types) + + +def get_jit_kernel_mlir(kernel, c): + # get the input types + in_types = [] + for x in c.avals_in: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + for x in c.avals_out: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + in_types = tuple(in_types) + # compile the kernel + return _compile_kernel_xla(kernel, in_types) + + +def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): + kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs']) + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) + + # create custom call + return xla_client.ops.CustomCallWithLayout( + c, + b'cupy_kernel_call_gpu', + operands=ins, + operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), + shape_with_layout=xla_client.Shape.tuple_shape( + [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) + for value in kwargs['outs']] + ), + opaque=opaque, + ) + + +def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel): + xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel) + + +def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): + kernel_func = get_jit_kernel_mlir(kernel, c) + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) + + input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] + result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] + output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] + + return custom_call( + call_target_name='cupy_kernel_call_gpu', + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + backend_config=opaque, + has_side_effect=False, + ).results + + +def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel): + if cp is None: + raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') + + rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel) + mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py new file mode 100644 index 00000000..772b6160 --- /dev/null +++ b/brainpy/_src/math/op_register/tests/test_cupy_based.py @@ -0,0 +1,79 @@ +import jax +import pytest + +import brainpy.math as bm +from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi + +cp = import_cupy(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) +ti = import_taichi(error_if_not_found=False) +if cp is None or ti is None: + pytest.skip('no cupy or taichi', allow_module_level=True) +bm.set_platform('cpu') + + +def test_cupy_based(): + bm.op_register.clear_taichi_aot_caches() + # Raw Module + + @ti.kernel + def simpleAdd(x1: ti.types.ndarray(ndim=2), + x2: ti.types.ndarray(ndim=2), + n: ti.types.ndarray(ndim=0), + y: ti.types.ndarray(ndim=2)): + for i, j in y: + y[i, j] = x1[i, j] + x2[i, j] + + source_code = r''' + extern "C"{ + + __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y) + { + unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < N) + { + y[tid] = x1[tid] + x2[tid]; + } + } + } + ''' + N = 10 + x1 = bm.ones((N, N)) + x2 = bm.ones((N, N)) + + mod = cp.RawModule(code=source_code) + kernel = mod.get_function('kernel') + + prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel) + + y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0] + + print(y) + assert bm.allclose(y, x1 + x2) + + # JIT Kernel + @ti.kernel + def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1), + size: ti.types.ndarray(ndim=1), + y: ti.types.ndarray(ndim=1)): + for i in y: + y[i] = x[i] + + @cp_jit.rawkernel() + def elementwise_copy(x, size, y): + tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x + ntid = cp_jit.gridDim.x * cp_jit.blockDim.x + for i in range(tid, size, ntid): + y[i] = x[i] + + size = 100 + x = bm.ones((size,)) + + prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy) + + y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0] + + print(y) + assert bm.allclose(y, x) + +# test_cupy_based() diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 199dce98..ea6dcadc 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -11,10 +11,9 @@ bm.set_platform('cpu') - @ti.func -def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: - return weight[0] +def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32: + return weight[None] @ti.func @@ -25,7 +24,7 @@ def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -35,11 +34,10 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - @ti.kernel def event_ell_gpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -48,21 +46,18 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 100)) + indices = bm.random.randint(0, s, (s, 1000)) vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) print(out) - bm.clear_buffer_memory() # test_taichi_op_register() diff --git a/docs/tutorial_advanced/operator_custom_with_cupy.ipynb b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb new file mode 100644 index 00000000..0b4bf241 --- /dev/null +++ b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CPU and GPU Operator Customization with CuPy\n", + "\n", + "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This functionality is only available for ``brainpylib>=0.3.1``. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## English Version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Although we can now use the flexible taichi custom operator approach, taichi on cuda does not have more fine-grained control or optimization for some scenarios. So for such scenarios, we can use cupy's \n", + "- `RawModule`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", + "- `jit.rawkernel`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n", + "\n", + "to compile and run CUDA native code directly as strings or cupy JIT function in real time for finer grained control.\n", + "\n", + "Start by importing the relevant Python package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "import cupy as cp\n", + "from cupyx import jit\n", + "\n", + "bm.set_platform('gpu')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy RawModule\n", + "\n", + "For dealing a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be more handy. It can be initialized either by a CUDA source code. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.\n", + "\n", + "Be aware that the order of parameters in the kernel function you want to call should **keep outputs at the end of the parameter list**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_code = r'''\n", + " extern \"C\"{\n", + "\n", + " __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n", + " {\n", + " unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n", + " if (tid < N)\n", + " {\n", + " y[tid] = x1[tid] + x2[tid];\n", + " }\n", + " }\n", + " }\n", + "'''\n", + "mod = cp.RawModule(code=source_code)\n", + "kernel = mod.get_function('kernel')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "N = 10\n", + "x1 = bm.ones((N, N))\n", + "x2 = bm.ones((N, N))\n", + "\n", + "# register the kernel as a custom op\n", + "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n", + "\n", + "# call the custom op\n", + "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy JIT RawKernel\n", + "The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.\n", + "\n", + "In this section, a Python function wrapped with the decorator is called a target function.\n", + "\n", + "Here is a short example for how to write a cupyx.jit.rawkernel to copy the values from x to y using a grid-stride loop:\n", + "\n", + "Launching a CUDA kernel on a GPU with pre-determined grid/block sizes requires basic understanding in the CUDA Programming Model. And the compilation will be deferred until the first function call. CuPy’s JIT compiler infers the types of arguments at the call time, and will cache the compiled kernels for speeding up any subsequent calls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jit.rawkernel()\n", + "def elementwise_copy(x, size, y):\n", + " tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n", + " ntid = jit.gridDim.x * jit.blockDim.x\n", + " for i in range(tid, size, ntid):\n", + " y[i] = x[i]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After define the `jit.rawkernel`. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "size = 100\n", + "x = bm.ones((size,))\n", + "\n", + "# register the kernel as a custom op\n", + "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n", + "\n", + "# call the custom op\n", + "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 3c2667df..4b86a426 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -99,8 +99,8 @@ "\n", "```python\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "@ti.func\n", "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n", @@ -109,7 +109,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -207,8 +207,8 @@ "bm.set_platform('cpu')\n", "\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "\n", "@ti.func\n", @@ -219,7 +219,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -232,7 +232,7 @@ "@ti.kernel\n", "def event_ell_gpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1), \n", - " weight: ti.types.ndarray(ndim=1), \n", + " weight: ti.types.ndarray(ndim=0), \n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -248,11 +248,10 @@ " s = 1000\n", " indices = bm.random.randint(0, s, (s, 1000))\n", " vector = bm.random.rand(s) < 0.1\n", - " weight = bm.array([1.0])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", " print(out)\n", "\n", From e0ee14248163d567380c5376cf9816c03c44ee02 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Fri, 22 Mar 2024 13:18:35 +0800 Subject: [PATCH 16/21] [math] Fix taichi custom operator on gpu backend (#655) --- .../_src/math/op_register/taichi_aot_based.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 2a8cb3b6..858f338b 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -316,11 +316,11 @@ def _preprocess_kernel_call_cpu( def _preprocess_kernel_call_gpu( source_md5_encode: str, - ins: dict, - outs: dict, + ins: Sequence, + outs: Sequence, ) -> bytes: - if len(ins) + len(outs) > 8: - raise ValueError('The number of ins and outs must be less than 8!') + # if len(ins) + len(outs) > 8: + # raise ValueError('The number of ins and outs must be less than 8!') kernel_path = os.path.join(kernels_aot_path, source_md5_encode) # other args @@ -331,18 +331,18 @@ def _preprocess_kernel_call_gpu( in_out_elem_count_list = [0] * param_total_num in_out_shape_list = [0] * param_total_num * 8 - for i, value in enumerate(ins.values()): - in_out_type_list[i] = type_number_map[value[0]] - in_out_dim_count_list[i] = len(value[1]) - in_out_elem_count_list[i] = reduce(lambda x, y: x * y, value[1]) - for j, dim in enumerate(value[1]): + for i, value in enumerate(ins): + in_out_type_list[i] = type_number_map[value.dtype] + in_out_dim_count_list[i] = value.ndim + in_out_elem_count_list[i] = value.size + for j, dim in enumerate(value.shape): in_out_shape_list[i * 8 + j] = dim - for i, value in enumerate(outs.values()): - in_out_type_list[i + len(ins)] = type_number_map[value[0]] - in_out_dim_count_list[i + len(ins)] = len(value[1]) - in_out_elem_count_list[i + len(ins)] = reduce(lambda x, y: x * y, value[1]) - for j, dim in enumerate(value[1]): + for i, value in enumerate(outs): + in_out_type_list[i + len(ins)] = type_number_map[value.dtype] + in_out_dim_count_list[i + len(ins)] = value.ndim + in_out_elem_count_list[i + len(ins)] = value.size + for j, dim in enumerate(value.shape): in_out_shape_list[(i + len(ins)) * 8 + j] = dim # covert to string @@ -407,7 +407,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs): # returns if platform in ['gpu', 'cuda']: import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict) + opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs) return opaque elif platform == 'cpu': import_brainpylib_cpu_ops() From 05394a27a893210b236d18cc548cc1124c3a06ff Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 23 Mar 2024 13:24:02 +0800 Subject: [PATCH 17/21] dtype checking during exponential euler method --- brainpy/_src/integrators/ode/exponential.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index e44e324e..ec0e1070 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -106,6 +106,9 @@ """ from functools import wraps + +import jax.numpy as jnp + from brainpy import errors from brainpy._src import math as bm from brainpy._src.integrators import constants as C, utils, joint_eq @@ -356,6 +359,9 @@ def _build_integrator(self, eq): # integration function def integral(*args, **kwargs): assert len(args) > 0 + if args[0].dtype not in [jnp.float32, jnp.float64, jnp.float16, jnp.bfloat16]: + raise ValueError('The input data type should be float32, float64, float16, or bfloat16 when using Exponential Euler method.' + f'But we got {args[0].dtype}.') dt = kwargs.pop(C.DT, self.dt) linear, derivative = value_and_grad(*args, **kwargs) phi = bm.exprel(dt * linear) From 4e6c7b6dda68d8654b27edea24ed8c1d3bdcd831 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 23 Mar 2024 13:29:18 +0800 Subject: [PATCH 18/21] update cupy operator custom doc (#656) --- docs/advanced_tutorials.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst index 0b78315a..2ddf3906 100644 --- a/docs/advanced_tutorials.rst +++ b/docs/advanced_tutorials.rst @@ -32,6 +32,7 @@ Brain Dynamics Dedicated Operators tutorial_advanced/operator_custom_with_numba.ipynb tutorial_advanced/operator_custom_with_taichi.ipynb + tutorial_advanced/operator_custom_with_cupy.ipynb Developer Guides From 1fcdc8611165ebe819401042a312603dcb116263 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 24 Mar 2024 10:23:32 +0800 Subject: [PATCH 19/21] version 2.6.0 (#657) --- brainpy/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 79aa216b..837efaf1 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -__version__ = "2.5.0" +__version__ = "2.6.0" # fundamental supporting modules from brainpy import errors, check, tools From 3e5dcbf30aeac0b062e9a69a39900a6cef8cd1bb Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 24 Mar 2024 10:26:29 +0800 Subject: [PATCH 20/21] Upgrade CI (#658) * version 2.6.0 * upgrade CI --- .github/workflows/CI-models.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/CI-models.yml b/.github/workflows/CI-models.yml index 2883600b..af52b817 100644 --- a/.github/workflows/CI-models.yml +++ b/.github/workflows/CI-models.yml @@ -32,6 +32,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + python -m pip install --upgrade pip if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install @@ -79,6 +80,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + python -m pip install --upgrade pip if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi pip uninstall brainpy -y python setup.py install @@ -127,6 +129,7 @@ jobs: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | + python -m pip install --upgrade pip python -m pip install numpy>=1.21.0 python -m pip install -r requirements-dev.txt python -m pip install tqdm brainpylib From 87858c54b4f4e45192ad4d9c6ff359f4c1e7ecf8 Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Mon, 25 Mar 2024 12:39:51 +0800 Subject: [PATCH 21/21] [doc] Add Chinese version of `operator_custom_with_cupy.ipynb` and Rename it's title (#659) * Update operator_custom_with_cupy.ipynb * Update operator_custom_with_cupy.ipynb * Update operator_custom_with_cupy.ipynb --- .../operator_custom_with_cupy.ipynb | 156 +++++++++++++++++- 1 file changed, 152 insertions(+), 4 deletions(-) diff --git a/docs/tutorial_advanced/operator_custom_with_cupy.ipynb b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb index 0b4bf241..19f302bb 100644 --- a/docs/tutorial_advanced/operator_custom_with_cupy.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb @@ -4,7 +4,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# CPU and GPU Operator Customization with CuPy\n", + "# GPU Operator Customization with CuPy\n", "\n", "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)\n", "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)" @@ -29,8 +29,8 @@ "metadata": {}, "source": [ "Although we can now use the flexible taichi custom operator approach, taichi on cuda does not have more fine-grained control or optimization for some scenarios. So for such scenarios, we can use cupy's \n", - "- `RawModule`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", - "- `jit.rawkernel`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n", + "- [`RawModule`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", + "- [`jit.rawkernel`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n", "\n", "to compile and run CUDA native code directly as strings or cupy JIT function in real time for finer grained control.\n", "\n", @@ -90,7 +90,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)." + "After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**).\n", + "\n", + "Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output." ] }, { @@ -162,6 +164,152 @@ "# call the custom op\n", "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 中文版\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "尽管我们现在可以使用灵活的taichi自定义操作符方法,但在cuda后端上,taichi没有更细粒度的控制或某些场景下的优化。因此,对于这类场景,我们可以使用cupy的\n", + "- [`RawModule`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", + "- [`jit.rawkernel`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n", + "\n", + "来直接作为字符串或cupy JIT函数实时编译并运行CUDA原生代码,以实现更细致的控制。\n", + "\n", + "首先,导入相关的Python包。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "import cupy as cp\n", + "from cupyx import jit\n", + "\n", + "bm.set_platform('gpu')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy RawModule\n", + "`RawModule`类可以通过传入CUDA源码的字符串来初始化,然后,通过调用`get_function()`方法可以检索所需的kernel,该方法返回一个可以调用的RawKernel实例。\n", + "\n", + "请注意,您想要调用的kernel中的参数顺序应该**将输出参数放在参数列表的末尾**。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_code = '''\n", + " extern \"C\"{\n", + "\n", + " __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n", + " {\n", + " unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n", + " if (tid < N)\n", + " {\n", + " y[tid] = x1[tid] + x2[tid];\n", + " }\n", + " }\n", + " }\n", + "'''\n", + "mod = cp.RawModule(code=source_code)\n", + "kernel = mod.get_function('kernel')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "定义了RawModule并获取了内核函数后,可以使用`bm.XLACustomOp`将其注册到其`gpu_kernel`中,并使用您想要的适当的`grid`和`block`调用它(**这里这两个参数都应该是元组**)。\n", + "\n", + "最后在调用中指定`outs`参数,用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 准备输入\n", + "N = 10\n", + "x1 = bm.ones((N, N))\n", + "x2 = bm.ones((N, N))\n", + "\n", + "# 将kernel注册为自定义算子\n", + "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n", + "\n", + "# 调用自定义算子\n", + "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy JIT RawKernel\n", + "\n", + "`cupyx.jit.rawkernel`装饰器可以从Python函数创建原生CUDA内核。\n", + "\n", + "以下是一个如何通过`cupyx.jit.rawkernel`来使用grid-stride循环从`x`复制值到`y`的简短示例:\n", + "\n", + "在GPU上启动CUDA内核,需要预先确定的grid/block大小,这需要对CUDA编程模型有基本的了解。编译将延迟到第一次函数调用时。CuPy的JIT编译器会在调用时推断参数的类型,并将缓存编译后的内核以加速任何后续调用。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jit.rawkernel()\n", + "def elementwise_copy(x, size, y):\n", + " tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n", + " ntid = jit.gridDim.x * jit.blockDim.x\n", + " for i in range(tid, size, ntid):\n", + " y[i] = x[i]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "定义了`jit.rawkernel`后,您可以使用`bm.XLACustomOp`将其注册到其`gpu_kernel`中,并使用您想要的适当的`grid`和`block`调用它(**这里这两个参数都应该是元组**)。" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# 准备输入\n", + "size = 100\n", + "x = bm.ones((size,))\n", + "\n", + "# 将kernel注册为自定义算子\n", + "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n", + "\n", + "# 调用自定义算子\n", + "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]\n" + ] } ], "metadata": {