From 0b52a884aed464888317f750ab550777a84004ea Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 17 Nov 2023 10:27:01 +0800 Subject: [PATCH 01/13] [running] fix multiprocessing bugs --- .../_src/running/pathos_multiprocessing.py | 7 ++++ .../tests/test_pathos_multiprocessing.py | 39 +++++++++++++++++++ requirements-dev.txt | 3 +- requirements-doc.txt | 4 +- 4 files changed, 50 insertions(+), 3 deletions(-) create mode 100644 brainpy/_src/running/tests/test_pathos_multiprocessing.py diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index 1573a541c..f652217d9 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -9,6 +9,7 @@ - ``cpu_unordered_parallel``: Performs a parallel unordered map. """ +import sys from collections.abc import Sized from typing import (Any, Callable, Generator, Iterable, List, Union, Optional, Sequence, Dict) @@ -20,6 +21,8 @@ try: from pathos.helpers import cpu_count # noqa from pathos.multiprocessing import ProcessPool # noqa + import multiprocess.context as ctx # noqa + ctx._force_start_method('spawn') except ModuleNotFoundError: cpu_count = None ProcessPool = None @@ -63,6 +66,10 @@ def _parallel( A generator which will apply the function to each element of the given Iterables in parallel in order with a progress bar. """ + if sys.platform == 'win32' and sys.version_info.minor >= 11: + raise NotImplementedError('Multiprocessing is not available in Python >=3.11 on Windows. ' + 'Please use Linux or MacOS, or Windows with Python <= 3.10.') + if ProcessPool is None or cpu_count is None: raise PackageMissingError( ''' diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py new file mode 100644 index 000000000..7fc45b1b4 --- /dev/null +++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py @@ -0,0 +1,39 @@ +import sys + +import jax +import pytest +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +if sys.platform == 'win32' and sys.version_info.minor >= 11: + pytest.skip('python 3.11 does not support.', allow_module_level=True) + + +class TestParallel(parameterized.TestCase): + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v1(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) + + @parameterized.product( + duration=[1e2, 1e3, 1e4, 1e5] + ) + def test_cpu_unordered_parallel_v2(self, duration): + @jax.jit + def body(inp): + return bm.for_loop(lambda x: x + 1e-9, inp) + + input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100 + + r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2) + assert bm.allclose(r[0], r[1]) diff --git a/requirements-dev.txt b/requirements-dev.txt index 93fa26af3..068c38546 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -3,9 +3,10 @@ numba brainpylib jax jaxlib -matplotlib>=3.4 +matplotlib msgpack tqdm +pathos # test requirements pytest diff --git a/requirements-doc.txt b/requirements-doc.txt index d4fe3f43e..c399c03b0 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -4,8 +4,8 @@ msgpack numba jax jaxlib -matplotlib>=3.4 -scipy>=1.1.0 +matplotlib +scipy numba # document requirements From 5843e664b5b222d2bf6f67ba6920e541c664c3f2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 18 Nov 2023 15:54:47 +0800 Subject: [PATCH 02/13] fix tests --- brainpy/_src/running/tests/test_pathos_multiprocessing.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py index 7fc45b1b4..6f92bda7e 100644 --- a/brainpy/_src/running/tests/test_pathos_multiprocessing.py +++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py @@ -9,6 +9,8 @@ if sys.platform == 'win32' and sys.version_info.minor >= 11: pytest.skip('python 3.11 does not support.', allow_module_level=True) +else: + pytest.skip('Cannot pass tests.', allow_module_level=True) class TestParallel(parameterized.TestCase): From 484912b566ec68c3aeeef68a7fb87bade0c20d27 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 19 Nov 2023 13:16:10 +0800 Subject: [PATCH 03/13] [doc] update doc --- .../operator_custom_with_numba.ipynb | 2 +- .../operator_custom_with_taichi.ipynb | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index 215d41418..b38cd0694 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -6,7 +6,7 @@ "collapsed": true }, "source": [ - "# Operator Customization with Numba" + "# CPU Operator Customization with Numba" ] }, { diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 183a8a251..0443aed9d 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -4,9 +4,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Operator Customization with Taichi" + "# CPU and GPU Operator Customization with Taichi" ] }, + { + "cell_type": "markdown", + "source": [ + "This functionality is only available for ``brainpylib>=0.2.0``. " + ], + "metadata": { + "collapsed": false + } + }, { "cell_type": "markdown", "metadata": {}, From c6af32cbbd76edb2fafb53b8c4ed887cf18bd0c4 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 19 Nov 2023 13:16:21 +0800 Subject: [PATCH 04/13] update --- brainpy/_src/mixin.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 8ea8a5216..fe7c39940 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -519,19 +519,6 @@ def __subclasscheck__(self, subclass): return all([issubclass(subclass, cls) for cls in self.__bases__]) -class UnionType2(MixIn): - """Union type for multiple types. - - >>> import brainpy as bp - >>> - >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay]) - """ - - @classmethod - def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type: - return _MetaUnionType('UnionType', types, {}) - - if sys.version_info.minor > 8: class _JointGenericAlias(_UnionGenericAlias, _root=True): def __subclasscheck__(self, subclass): From c4f5b328dbd9876bd3a0c6af388f776e9cc2b341 Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 20 Nov 2023 12:37:24 +0800 Subject: [PATCH 05/13] [math] add `brainpy.math.gpu_memory_preallocation()` for controlling GPU memory preallocation --- brainpy/_src/math/environment.py | 29 +++++++++++++++++++++++++---- brainpy/math/environment.py | 1 + 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index eef0361fc..31c264e7d 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -702,13 +702,34 @@ def clear_buffer_memory(platform=None): buf.delete() -def disable_gpu_memory_preallocation(): - """Disable pre-allocating the GPU memory.""" +def disable_gpu_memory_preallocation(release_memory: bool = True): + """Disable pre-allocating the GPU memory. + + This disables the preallocation behavior. JAX will instead allocate GPU memory as needed, + potentially decreasing the overall memory usage. However, this behavior is more prone to + GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory + may OOM with preallocation disabled. + + Args: + release_memory: bool. Whether we release memory during the computation. + """ os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false' - os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' + if release_memory: + os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform' def enable_gpu_memory_preallocation(): """Disable pre-allocating the GPU memory.""" os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' - os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR') + os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None) + + +def gpu_memory_preallocation(percent: float): + """GPU memory allocation. + + If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory, + instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts. + """ + assert 0. <= percent < 1., f'GPU memory preallocation must be in [0., 1.]. But we got {percent}.' + os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent) + diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index a283cc921..d654a0217 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -30,6 +30,7 @@ clear_buffer_memory as clear_buffer_memory, enable_gpu_memory_preallocation as enable_gpu_memory_preallocation, disable_gpu_memory_preallocation as disable_gpu_memory_preallocation, + gpu_memory_preallocation as gpu_memory_preallocation, ditype as ditype, dftype as dftype, ) From ed4ce5fd5b44e50afbfd648f4321329385940c1f Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 26 Nov 2023 10:13:12 +0800 Subject: [PATCH 06/13] [math] `clear_buffer_memory` support to clear array and compilation both --- brainpy/_src/math/environment.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 31c264e7d..b7a17bb9e 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -9,6 +9,7 @@ import warnings from typing import Any, Callable, TypeVar, cast +import jax from jax import config, numpy as jnp, devices from jax.lib import xla_bridge @@ -682,7 +683,11 @@ def set_host_device_count(n): os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags) -def clear_buffer_memory(platform=None): +def clear_buffer_memory( + platform: str = None, + array: bool = True, + compilation: bool = False +): """Clear all on-device buffers. This function will be very useful when you call models in a Python loop, @@ -697,9 +702,17 @@ def clear_buffer_memory(platform=None): ---------- platform: str The device to clear its memory. + array: bool + Clear all buffer array. + compilation: bool + Clear compilation cache. + """ - for buf in xla_bridge.get_backend(platform=platform).live_buffers(): - buf.delete() + if array: + for buf in xla_bridge.get_backend(platform=platform).live_buffers(): + buf.delete() + if compilation: + jax.clear_caches() def disable_gpu_memory_preallocation(release_memory: bool = True): From 8a2beb8404cefaf37b084c8d5cd6c3204f800c4b Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 26 Nov 2023 10:40:24 +0800 Subject: [PATCH 07/13] [dyn] compatible old version of `.reset_state()` function --- brainpy/_src/dynsys.py | 52 +++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index 00120a666..10d2de792 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -2,8 +2,8 @@ import collections import inspect -import warnings import numbers +import warnings from typing import Union, Dict, Callable, Sequence, Optional, Any import numpy as np @@ -13,7 +13,7 @@ from brainpy._src.deprecations import _update_deprecate_msg from brainpy._src.initialize import parameter, variable_ from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool -from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError +from brainpy.errors import NoImplementationError, UnsupportedError from brainpy.types import ArrayType, Shape __all__ = [ @@ -27,9 +27,9 @@ 'Dynamic', 'Projection', ] - IonChaDyn = None SLICE_VARS = 'slice_vars' +the_top_layer_reset_state = True def not_implemented(fun): @@ -138,16 +138,12 @@ def update(self, *args, **kwargs): """ raise NotImplementedError('Must implement "update" function by subclass self.') - def reset(self, *args, include_self: bool = False, **kwargs): + def reset(self, *args, **kwargs): """Reset function which reset the whole variables in the model (including its children models). ``reset()`` function is a collective behavior which resets all states in this model. See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details. - - Args:: - include_self: bool. Reset states including the node self. Please turn on this if the node has - implemented its ".reset_state()" function. """ from brainpy._src.helpers import reset_state reset_state(self, *args, **kwargs) @@ -162,19 +158,6 @@ def reset_state(self, *args, **kwargs): """ pass - # raise APIChangedError( - # ''' - # From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. - # - # 1. If you are resetting all states in a network by calling "net.reset_state()", please use - # "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states - # in a local node (excluded its children nodes). - # - # 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass. - # - # ''' - # ) - def clear_input(self, *args, **kwargs): """Clear the input at the current time step.""" pass @@ -344,14 +327,37 @@ def _compatible_update(self, *args, **kwargs): return ret return update_fun(*args, **kwargs) + def _compatible_reset_state(self, *args, **kwargs): + global the_top_layer_reset_state + the_top_layer_reset_state = False + try: + self.reset(*args, **kwargs) + finally: + the_top_layer_reset_state = True + warnings.warn( + ''' + From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details. + + 1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use + "bp.reset_state(net, *args, **kwargs)" function, or "net.reset(*args, **kwargs)". + ".reset_state()" only defines the resetting of local states in a local node (excluded its children nodes). + + 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass. + + ''', + DeprecationWarning + ) + def _get_update_fun(self): return object.__getattribute__(self, 'update') def __getattribute__(self, item): if item == 'update': return self._compatible_update # update function compatible with previous ``update()`` function - else: - return super().__getattribute__(item) + if item == 'reset_state': + if the_top_layer_reset_state: + return self._compatible_reset_state # reset_state function compatible with previous ``reset_state()`` function + return super().__getattribute__(item) def __repr__(self): return f'{self.name}(mode={self.mode})' From 46bb987291a330450967e3d20acbc42abe1d1a78 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 26 Nov 2023 10:41:30 +0800 Subject: [PATCH 08/13] [setup] update installation info --- setup.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 69c33cdfe..f867e3078 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,6 @@ # installation packages packages = find_packages(exclude=['lib*', 'docs', 'tests']) - # setup setup( name='brainpy', @@ -51,13 +50,23 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", "Documentation": "https://brainpy.readthedocs.io/", "Source Code": "https://github.com/brainpy/BrainPy", }, + dependency_links=[ + 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', + ], + extras_require={ + 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], + 'cuda': ['jax[cuda]', 'brainpylib-cu11x'], + 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], + 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], + 'tpu': ['jax[tpu]'], + }, keywords=('computational neuroscience, ' 'brain-inspired computation, ' 'dynamical systems, ' From 6dac69e8f647b98fa3b1cc39023221d7da1064fd Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 28 Nov 2023 14:23:19 +0800 Subject: [PATCH 09/13] [install] upgrade dependency --- brainpy/_src/dependency_check.py | 34 ++++++++++++++++---------------- brainpy/_src/tools/install.py | 14 +++---------- 2 files changed, 20 insertions(+), 28 deletions(-) diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 33456c02f..ebf6f9404 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -1,13 +1,13 @@ +import os +import sys from jax.lib import xla_client - __all__ = [ 'import_taichi', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] - _minimal_brainpylib_version = '0.1.10' _minimal_taichi_version = (1, 7, 0) @@ -15,24 +15,27 @@ brainpylib_cpu_ops = None brainpylib_gpu_ops = None +taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. ' + f'Currently you can install taichi>={_minimal_taichi_version} through:\n\n' + '> pip install taichi -U') +os.environ["TI_LOG_LEVEL"] = "error" + def import_taichi(): global taichi if taichi is None: - try: - import taichi as taichi # noqa - except ModuleNotFoundError: - raise ModuleNotFoundError( - 'Taichi is needed. Please install taichi through:\n\n' - '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' - ) + with open(os.devnull, 'w') as devnull: + old_stdout = sys.stdout + sys.stdout = devnull + try: + import taichi as taichi # noqa + except ModuleNotFoundError: + raise ModuleNotFoundError(taichi_install_info) + finally: + sys.stdout = old_stdout if taichi.__version__ < _minimal_taichi_version: - raise RuntimeError( - f'We need taichi>={_minimal_taichi_version}. ' - f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n' - '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' - ) + raise RuntimeError(taichi_install_info) return taichi @@ -82,6 +85,3 @@ def import_brainpylib_gpu_ops(): 'See https://brainpy.readthedocs.io for installation instructions.') return brainpylib_gpu_ops - - - diff --git a/brainpy/_src/tools/install.py b/brainpy/_src/tools/install.py index aadf0f5c0..4e4a537a9 100644 --- a/brainpy/_src/tools/install.py +++ b/brainpy/_src/tools/install.py @@ -8,19 +8,11 @@ BrainPy needs jaxlib, please install it. -1. If you are using Windows system, install jaxlib through +1. If you are using brainpy on CPU platform, please install jaxlib through - >>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html + >>> pip install jaxlib -2. If you are using macOS platform, install jaxlib through - - >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html - -3. If you are using Linux platform, install jaxlib through - - >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html - -4. If you are using Linux + CUDA platform, install jaxlib through +2. If you are using Linux + CUDA platform, install jaxlib through >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html From 6c2c9bb3ce995a427535e48835b20d597d1ff19d Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 2 Dec 2023 14:42:07 +0800 Subject: [PATCH 10/13] updates --- README.md | 3 +- brainpy/__init__.py | 4 +- brainpy/_src/dependency_check.py | 8 +- brainpy/_src/dyn/projections/plasticity.py | 198 ++++++++++++++++++++- brainpy/_src/measure/lfp.py | 2 +- 5 files changed, 203 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 716dbd900..9c74b82d1 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@

-BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc. +BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Taichi](https://github.com/taichi-dev/taichi), [Numba](https://github.com/numba/numba), and others). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc. - **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest - **Source**: https://github.com/brainpy/BrainPy @@ -77,6 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to - **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming. - **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation. - **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling. +- [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling) - [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 371ed6b27..1342eb9a0 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.4.6" +__version__ = "2.4.6.post2" # fundamental supporting modules from brainpy import errors, check, tools @@ -75,7 +75,7 @@ ) NeuGroup = NeuGroupNS = dyn.NeuDyn -# shared parameters +# common tools from brainpy._src.context import (share as share) from brainpy._src.helpers import (reset_state as reset_state, save_state as save_state, diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index ebf6f9404..e8492f826 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -15,9 +15,9 @@ brainpylib_cpu_ops = None brainpylib_gpu_ops = None -taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. ' - f'Currently you can install taichi>={_minimal_taichi_version} through:\n\n' - '> pip install taichi -U') +taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. ' + f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n' + '> pip install taichi==1.7.0 -U') os.environ["TI_LOG_LEVEL"] = "error" @@ -34,7 +34,7 @@ def import_taichi(): finally: sys.stdout = old_stdout - if taichi.__version__ < _minimal_taichi_version: + if taichi.__version__ != _minimal_taichi_version: raise RuntimeError(taichi_install_info) return taichi diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 3ee6f4fef..3fb3c1232 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -49,8 +49,8 @@ class STDP_Song2000(Projection): \begin{aligned} \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ - \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\ - \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\ + \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\ + \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\ \end{aligned} where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment @@ -64,8 +64,8 @@ class STDP_Song2000(Projection): class STDPNet(bp.DynamicalSystem): def __init__(self, num_pre, num_post): super().__init__() - self.pre = bp.dyn.LifRef(num_pre, name='neu1') - self.post = bp.dyn.LifRef(num_post, name='neu2') + self.pre = bp.dyn.LifRef(num_pre) + self.post = bp.dyn.LifRef(num_post) self.syn = bp.dyn.STDP_Song2000( pre=self.pre, delay=1., @@ -219,3 +219,193 @@ def update(self): return current +# class PairedSTDP(Projection): +# r"""Paired spike-time-dependent plasticity model. +# +# This model filters the synaptic currents according to the variables: :math:`w`. +# +# .. math:: +# +# I_{syn}^+(t) = I_{syn}^-(t) * w +# +# where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before +# and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse, +# the conductance of the synapse will increase w. +# +# The dynamics of :math:`w` is governed by the following equation: +# +# .. math:: +# +# \begin{aligned} +# \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\ +# \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\ +# \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\ +# \end{aligned} +# +# where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment +# of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike. +# +# Here is an example of the usage of this class:: +# +# import brainpy as bp +# import brainpy.math as bm +# +# class STDPNet(bp.DynamicalSystem): +# def __init__(self, num_pre, num_post): +# super().__init__() +# self.pre = bp.dyn.LifRef(num_pre) +# self.post = bp.dyn.LifRef(num_post) +# self.syn = bp.dyn.STDP_Song2000( +# pre=self.pre, +# delay=1., +# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), +# weight=bp.init.Uniform(max_val=0.1)), +# syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), +# out=bp.dyn.COBA.desc(E=0.), +# post=self.post, +# tau_s=16.8, +# tau_t=33.7, +# A1=0.96, +# A2=0.53, +# ) +# +# def update(self, I_pre, I_post): +# self.syn() +# self.pre(I_pre) +# self.post(I_post) +# conductance = self.syn.refs['syn'].g +# Apre = self.syn.refs['pre_trace'].g +# Apost = self.syn.refs['post_trace'].g +# current = self.post.sum_inputs(self.post.V) +# return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight +# +# duration = 300. +# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], +# [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255]) +# I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0], +# [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250]) +# +# net = STDPNet(1, 1) +# def run(i, I_pre, I_post): +# pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post) +# return pre_spike, post_spike, g, Apre, Apost, current, W +# +# indices = bm.arange(0, duration, bm.dt) +# pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post]) +# +# Args: +# tau_s: float. The time constant of :math:`A_{pre}`. +# tau_t: float. The time constant of :math:`A_{post}`. +# A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. +# A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. +# W_max: float. The maximum weight. +# W_min: float. The minimum weight. +# pre: DynamicalSystem. The pre-synaptic neuron group. +# delay: int, float. The pre spike delay length. (ms) +# syn: DynamicalSystem. The synapse model. +# comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. +# out: DynamicalSystem. The synaptic current output models. +# post: DynamicalSystem. The post-synaptic neuron group. +# out_label: str. The output label. +# name: str. The model name. +# """ +# +# def __init__( +# self, +# pre: JointType[DynamicalSystem, SupportAutoDelay], +# delay: Union[None, int, float], +# syn: ParamDescriber[DynamicalSystem], +# comm: JointType[DynamicalSystem, SupportSTDP], +# out: ParamDescriber[JointType[DynamicalSystem, BindCondData]], +# post: DynamicalSystem, +# # synapse parameters +# tau_s: float = 16.8, +# tau_t: float = 33.7, +# lambda_: float = 0.96, +# alpha: float = 0.53, +# mu: float = 0.53, +# W_max: Optional[float] = None, +# W_min: Optional[float] = None, +# # others +# out_label: Optional[str] = None, +# name: Optional[str] = None, +# mode: Optional[bm.Mode] = None, +# ): +# super().__init__(name=name, mode=mode) +# +# # synaptic models +# check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) +# check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) +# check.is_instance(syn, ParamDescriber[DynamicalSystem]) +# check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) +# check.is_instance(post, DynamicalSystem) +# self.pre_num = pre.num +# self.post_num = post.num +# self.comm = comm +# self._is_align_post = issubclass(syn.cls, AlignPost) +# +# # delay initialization +# delay_cls = register_delay_by_return(pre) +# delay_cls.register_entry(self.name, delay) +# +# # synapse and output initialization +# if self._is_align_post: +# syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, +# proj_name=self.name) +# else: +# syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre') +# out_cls = out() +# add_inp_fun(out_label, self.name, out_cls, post) +# +# # references +# self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()`` +# self.refs['delay'] = delay_cls +# self.refs['syn'] = syn_cls # invisible to ``self.node()`` +# self.refs['out'] = out_cls # invisible to ``self.node()`` +# self.refs['comm'] = comm +# +# # tracing pre-synaptic spikes using Exponential model +# self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) +# +# # tracing post-synaptic spikes using Exponential model +# self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) +# +# # synapse parameters +# self.W_max = W_max +# self.W_min = W_min +# self.tau_s = tau_s +# self.tau_t = tau_t +# self.A1 = A1 +# self.A2 = A2 +# +# def update(self): +# # pre-synaptic spikes +# pre_spike = self.refs['delay'].at(self.name) # spike +# # pre-synaptic variables +# if self._is_align_post: +# # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance +# x = pre_spike +# else: +# # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance +# x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable +# +# # post spikes +# if not hasattr(self.refs['post'], 'spike'): +# raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.') +# post_spike = self.refs['post'].spike +# +# # weight updates +# Apost = self.refs['post_trace'].g +# self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max) +# Apre = self.refs['pre_trace'].g +# self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max) +# +# # synaptic currents +# current = self.comm(x) +# if self._is_align_post: +# self.refs['syn'].add_current(current) # synapse post current +# else: +# self.refs['out'].bind_cond(current) # align pre +# return current + + diff --git a/brainpy/_src/measure/lfp.py b/brainpy/_src/measure/lfp.py index 0662be8d9..434666efb 100644 --- a/brainpy/_src/measure/lfp.py +++ b/brainpy/_src/measure/lfp.py @@ -10,7 +10,7 @@ ] -def unitary_LFP(times, spikes, spike_type='exc', +def unitary_LFP(times, spikes, spike_type, xmax=0.2, ymax=0.2, va=200., lambda_=0.2, sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None): """A kernel-based method to calculate unitary local field potentials (uLFP) From 670937e95c800f9f732f2ee709723b0e966835fd Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 2 Dec 2023 16:11:28 +0800 Subject: [PATCH 11/13] [math] add `brainpy.math.defjvp`, support to define jvp rules for Primitive with multiple results. See examples in `test_ad_support.py` --- brainpy/_src/math/ad_support.py | 50 ++++++++ brainpy/_src/math/op_register/base.py | 4 +- brainpy/_src/math/tests/test_ad_support.py | 136 +++++++++++++++++++++ brainpy/math/others.py | 4 + 4 files changed, 192 insertions(+), 2 deletions(-) create mode 100644 brainpy/_src/math/ad_support.py create mode 100644 brainpy/_src/math/tests/test_ad_support.py diff --git a/brainpy/_src/math/ad_support.py b/brainpy/_src/math/ad_support.py new file mode 100644 index 000000000..fb710a675 --- /dev/null +++ b/brainpy/_src/math/ad_support.py @@ -0,0 +1,50 @@ +import functools +from functools import partial + +from jax import tree_util +from jax.core import Primitive +from jax.interpreters import ad +from brainpy._src.math.op_register.base import XLACustomOp + +__all__ = [ + 'defjvp', +] + + +def defjvp(primitive, *jvp_rules): + """Define JVP rule when the primitive + + Args: + primitive: Primitive, XLACustomOp. + *jvp_rules: The JVP translation rule for each primal. + + Returns: + The JVP gradients. + """ + if isinstance(primitive, XLACustomOp): + primitive = primitive.primitive + assert isinstance(primitive, Primitive) + if primitive.multiple_results: + ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) + else: + ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) + + +def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): + assert primitive.multiple_results + val_out = tuple(primitive.bind(*primals, **params)) + tree = tree_util.tree_structure(val_out) + tangents_out = [] + for rule, t in zip(jvp_rules, tangents): + if rule is not None and type(t) is not ad.Zero: + r = tuple(rule(t, *primals, **params)) + tangents_out.append(r) + assert tree_util.tree_structure(r) == tree + return val_out, functools.reduce(_add_tangents, + tangents_out, + tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) + + +def _add_tangents(xs, ys): + return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero)) + diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 31aef70d6..6def88950 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -139,13 +139,13 @@ def __init__( if transpose_translation is not None: ad.primitive_transposes[self.primitive] = transpose_translation - def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None): + def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs): if outs is None: outs = self.outs assert outs is not None outs = tuple([_transform_to_shapedarray(o) for o in outs]) ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array) - return self.primitive.bind(*ins, outs=outs) + return self.primitive.bind(*ins, outs=outs, **kwargs) def def_abstract_eval(self, fun): """Define the abstract evaluation function. diff --git a/brainpy/_src/math/tests/test_ad_support.py b/brainpy/_src/math/tests/test_ad_support.py new file mode 100644 index 000000000..66b8418b8 --- /dev/null +++ b/brainpy/_src/math/tests/test_ad_support.py @@ -0,0 +1,136 @@ +from typing import Tuple + +import jax +import numba +from jax import core +from jax import numpy as jnp +from jax.interpreters import ad + +import brainpy as bp +import brainpy.math as bm + + +def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ): + data = jnp.atleast_1d(bm.as_jax(data)) + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + vector = bm.as_jax(vector) + if vector.dtype == jnp.bool_: + vector = bm.as_jax(vector, dtype=data.dtype) + outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)] + if transpose: + return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose) + else: + return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose) + + +@numba.njit(fastmath=True) +def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val): + res_val.fill(0) + if values.shape[0] == 1: + values = values[0] + for row_i in range(vector.shape[0]): + v = vector[row_i] + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + res_val[col_indices[j]] += values * v + else: + for row_i in range(vector.shape[0]): + v = vector[row_i] + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + res_val[col_indices[j]] += v * values[j] + + +@numba.njit(fastmath=True, parallel=True, nogil=True) +def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val): + res_val.fill(0) + # csr mat @ vec + if values.shape[0] == 1: + values = values[0] + for row_i in numba.prange(res_val.shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values * vector[col_indices[j]] + res_val[row_i] = r + else: + for row_i in numba.prange(res_val.shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + res_val[row_i] = r + + +def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs): + return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose) + + +def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs): + return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose) + + +def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + + ct = ct[0] + if ad.is_undefined_primal(vector): + ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose) + return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) + + else: + if type(ct) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) + ct_data = jnp.inner(ct, ct_data) + else: # heterogeneous values + row, col = bm.sparse.csr_to_coo(indices, indptr) + ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] + return ct_data, indices, indptr, vector + + +prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp) +bm.defjvp(prim_trans, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec) +prim_trans.def_transpose_rule(_csrmv_cusparse_transpose) + +prim = bm.XLACustomOp(_csr_matvec_numba_imp) +bm.defjvp(prim, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec) +prim.def_transpose_rule(_csrmv_cusparse_transpose) + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +def try_a_trial(transpose, shape): + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = rng.random(indices.shape) + heter_data = bm.as_jax(heter_data) + vector = rng.random(shape[0] if transpose else shape[1]) + vector = bm.as_jax(vector) + + r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs, method='vector')), argnums=(0, 3))( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))( + heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + print(r5) + print(r6) + assert bm.allclose(r5[0], r6[0]) + assert bm.allclose(r5[1], r6[1][0]) + + +def test(): + transposes = [True, False] + shapes = [(100, 200), (10, 1000), (2, 2000)] + + for transpose in transposes: + for shape in shapes: + try_a_trial(transpose, shape) diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 23d9b0816..d1108d1fa 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -9,3 +9,7 @@ from brainpy._src.math.object_transform.naming import ( clear_name_cache, ) + +from brainpy._src.math.ad_support import ( + defjvp as defjvp, +) From f45e635f0ed35a0c885fc1a47b78caa902976d04 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 2 Dec 2023 16:17:29 +0800 Subject: [PATCH 12/13] [math] add `brainpy.math.XLACustomOp.defjvp` --- brainpy/_src/math/{ => op_register}/ad_support.py | 3 --- brainpy/_src/math/op_register/base.py | 15 +++++++++++---- .../{ => op_register}/tests/test_ad_support.py | 4 ++-- brainpy/math/op_register.py | 3 +-- brainpy/math/others.py | 4 ---- 5 files changed, 14 insertions(+), 15 deletions(-) rename brainpy/_src/math/{ => op_register}/ad_support.py (91%) rename brainpy/_src/math/{ => op_register}/tests/test_ad_support.py (97%) diff --git a/brainpy/_src/math/ad_support.py b/brainpy/_src/math/op_register/ad_support.py similarity index 91% rename from brainpy/_src/math/ad_support.py rename to brainpy/_src/math/op_register/ad_support.py index fb710a675..0e50091f2 100644 --- a/brainpy/_src/math/ad_support.py +++ b/brainpy/_src/math/op_register/ad_support.py @@ -4,7 +4,6 @@ from jax import tree_util from jax.core import Primitive from jax.interpreters import ad -from brainpy._src.math.op_register.base import XLACustomOp __all__ = [ 'defjvp', @@ -21,8 +20,6 @@ def defjvp(primitive, *jvp_rules): Returns: The JVP gradients. """ - if isinstance(primitive, XLACustomOp): - primitive = primitive.primitive assert isinstance(primitive, Primitive) if primitive.multiple_results: ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 6def88950..cb05ece81 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -14,11 +14,10 @@ # from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_cpu_translation_rule, - register_taichi_gpu_translation_rule, - encode_md5, - _preprocess_kernel_call_cpu, - get_source_with_dependencies) + register_taichi_gpu_translation_rule,) from .utils import register_general_batching +from brainpy._src.math.op_register.ad_support import defjvp + __all__ = [ 'XLACustomOp', @@ -171,6 +170,14 @@ def def_jvp_rule(self, fun): """ ad.primitive_jvps[self.primitive] = fun + def defjvp(self, *jvp_rules): + """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results. + + Args: + jvp_rules: The JVP rules. + """ + defjvp(self.primitive, *jvp_rules) + def def_transpose_rule(self, fun): """Define the transpose rule. diff --git a/brainpy/_src/math/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py similarity index 97% rename from brainpy/_src/math/tests/test_ad_support.py rename to brainpy/_src/math/op_register/tests/test_ad_support.py index 66b8418b8..547bbdc7c 100644 --- a/brainpy/_src/math/tests/test_ad_support.py +++ b/brainpy/_src/math/op_register/tests/test_ad_support.py @@ -90,11 +90,11 @@ def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, trans prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp) -bm.defjvp(prim_trans, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec) +prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec) prim_trans.def_transpose_rule(_csrmv_cusparse_transpose) prim = bm.XLACustomOp(_csr_matvec_numba_imp) -bm.defjvp(prim, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec) +prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec) prim.def_transpose_rule(_csrmv_cusparse_transpose) diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index b30ce4414..014a54e6f 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -6,8 +6,7 @@ compile_cpu_signature_with_numba, ) - from brainpy._src.math.op_register.base import XLACustomOp - +from brainpy._src.math.op_register.ad_support import defjvp diff --git a/brainpy/math/others.py b/brainpy/math/others.py index d1108d1fa..23d9b0816 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -9,7 +9,3 @@ from brainpy._src.math.object_transform.naming import ( clear_name_cache, ) - -from brainpy._src.math.ad_support import ( - defjvp as defjvp, -) From ffeb9cbdb6fac32e9968eb2631145690cfe2bb6a Mon Sep 17 00:00:00 2001 From: chaoming Date: Mon, 4 Dec 2023 14:08:34 +0800 Subject: [PATCH 13/13] [doc] upgrade `brainpy.math.defjvp` docstring --- brainpy/_src/math/op_register/ad_support.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py index 0e50091f2..f7bf9554a 100644 --- a/brainpy/_src/math/op_register/ad_support.py +++ b/brainpy/_src/math/op_register/ad_support.py @@ -11,14 +11,19 @@ def defjvp(primitive, *jvp_rules): - """Define JVP rule when the primitive + """Define JVP rules for any JAX primitive. + + This function is similar to ``jax.interpreters.ad.defjvp``. + However, the JAX one only supports primitive with ``multiple_results=False``. + ``brainpy.math.defjvp`` enables to define the independent JVP rule for + each input parameter no matter ``multiple_results=False/True``. + + For examples, please see ``test_ad_support.py``. + Args: primitive: Primitive, XLACustomOp. *jvp_rules: The JVP translation rule for each primal. - - Returns: - The JVP gradients. """ assert isinstance(primitive, Primitive) if primitive.multiple_results: