From 0b52a884aed464888317f750ab550777a84004ea Mon Sep 17 00:00:00 2001
From: chaoming
Date: Fri, 17 Nov 2023 10:27:01 +0800
Subject: [PATCH 01/13] [running] fix multiprocessing bugs
---
.../_src/running/pathos_multiprocessing.py | 7 ++++
.../tests/test_pathos_multiprocessing.py | 39 +++++++++++++++++++
requirements-dev.txt | 3 +-
requirements-doc.txt | 4 +-
4 files changed, 50 insertions(+), 3 deletions(-)
create mode 100644 brainpy/_src/running/tests/test_pathos_multiprocessing.py
diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py
index 1573a541c..f652217d9 100644
--- a/brainpy/_src/running/pathos_multiprocessing.py
+++ b/brainpy/_src/running/pathos_multiprocessing.py
@@ -9,6 +9,7 @@
- ``cpu_unordered_parallel``: Performs a parallel unordered map.
"""
+import sys
from collections.abc import Sized
from typing import (Any, Callable, Generator, Iterable, List,
Union, Optional, Sequence, Dict)
@@ -20,6 +21,8 @@
try:
from pathos.helpers import cpu_count # noqa
from pathos.multiprocessing import ProcessPool # noqa
+ import multiprocess.context as ctx # noqa
+ ctx._force_start_method('spawn')
except ModuleNotFoundError:
cpu_count = None
ProcessPool = None
@@ -63,6 +66,10 @@ def _parallel(
A generator which will apply the function to each element of the given Iterables
in parallel in order with a progress bar.
"""
+ if sys.platform == 'win32' and sys.version_info.minor >= 11:
+    raise NotImplementedError('Multiprocessing is not available in Python >= 3.11 on Windows. '
+                              'Please use Linux or macOS, or Windows with Python <= 3.10.')
+
if ProcessPool is None or cpu_count is None:
raise PackageMissingError(
'''
diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py
new file mode 100644
index 000000000..7fc45b1b4
--- /dev/null
+++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py
@@ -0,0 +1,39 @@
+import sys
+
+import jax
+import pytest
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+
+if sys.platform == 'win32' and sys.version_info.minor >= 11:
+  pytest.skip('Multiprocessing is not supported on Windows with Python >= 3.11.', allow_module_level=True)
+
+
+class TestParallel(parameterized.TestCase):
+ @parameterized.product(
+ duration=[1e2, 1e3, 1e4, 1e5]
+ )
+  def test_cpu_ordered_parallel_v1(self, duration):
+ @jax.jit
+ def body(inp):
+ return bm.for_loop(lambda x: x + 1e-9, inp)
+
+ input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100
+
+ r = bp.running.cpu_ordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
+ assert bm.allclose(r[0], r[1])
+
+ @parameterized.product(
+ duration=[1e2, 1e3, 1e4, 1e5]
+ )
+ def test_cpu_unordered_parallel_v2(self, duration):
+ @jax.jit
+ def body(inp):
+ return bm.for_loop(lambda x: x + 1e-9, inp)
+
+ input_long = bm.random.randn(1, int(duration / bm.dt), 3) / 100
+
+ r = bp.running.cpu_unordered_parallel(body, {'inp': [input_long, input_long]}, num_process=2)
+ assert bm.allclose(r[0], r[1])
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 93fa26af3..068c38546 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -3,9 +3,10 @@ numba
brainpylib
jax
jaxlib
-matplotlib>=3.4
+matplotlib
msgpack
tqdm
+pathos
# test requirements
pytest
diff --git a/requirements-doc.txt b/requirements-doc.txt
index d4fe3f43e..c399c03b0 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -4,8 +4,8 @@ msgpack
numba
jax
jaxlib
-matplotlib>=3.4
-scipy>=1.1.0
+matplotlib
+scipy
numba
# document requirements
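
Because this patch pins the pool to the 'spawn' start method, parallel calls must now be issued from a guarded entry point, since spawned workers re-import the main module. A minimal sketch of the intended usage (assuming `pathos` is installed and a supported platform; `simulate` and its arguments are illustrative only):

    import brainpy as bp
    import brainpy.math as bm

    def simulate(seed):
        # runs in a child process started with the 'spawn' method
        bm.random.seed(seed)
        return bm.as_numpy(bm.random.randn(3))

    if __name__ == '__main__':  # required under 'spawn'
        results = bp.running.cpu_ordered_parallel(
            simulate, {'seed': [1, 2, 3, 4]}, num_process=2)
        print(results)
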
From 5843e664b5b222d2bf6f67ba6920e541c664c3f2 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 18 Nov 2023 15:54:47 +0800
Subject: [PATCH 02/13] fix tests
---
brainpy/_src/running/tests/test_pathos_multiprocessing.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/brainpy/_src/running/tests/test_pathos_multiprocessing.py b/brainpy/_src/running/tests/test_pathos_multiprocessing.py
index 7fc45b1b4..6f92bda7e 100644
--- a/brainpy/_src/running/tests/test_pathos_multiprocessing.py
+++ b/brainpy/_src/running/tests/test_pathos_multiprocessing.py
@@ -9,6 +9,8 @@
if sys.platform == 'win32' and sys.version_info.minor >= 11:
   pytest.skip('Multiprocessing is not supported on Windows with Python >= 3.11.', allow_module_level=True)
+else:
+  pytest.skip('These tests cannot pass at the moment; skipping.', allow_module_level=True)
class TestParallel(parameterized.TestCase):
From 484912b566ec68c3aeeef68a7fb87bade0c20d27 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 19 Nov 2023 13:16:10 +0800
Subject: [PATCH 03/13] [doc] update doc
---
.../operator_custom_with_numba.ipynb | 2 +-
.../operator_custom_with_taichi.ipynb | 11 ++++++++++-
2 files changed, 11 insertions(+), 2 deletions(-)
diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
index 215d41418..b38cd0694 100644
--- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -6,7 +6,7 @@
"collapsed": true
},
"source": [
- "# Operator Customization with Numba"
+ "# CPU Operator Customization with Numba"
]
},
{
diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
index 183a8a251..0443aed9d 100644
--- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
@@ -4,9 +4,18 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# Operator Customization with Taichi"
+ "# CPU and GPU Operator Customization with Taichi"
]
},
+ {
+ "cell_type": "markdown",
+ "source": [
+ "This functionality is only available for ``brainpylib>=0.2.0``. "
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
{
"cell_type": "markdown",
"metadata": {},
From c6af32cbbd76edb2fafb53b8c4ed887cf18bd0c4 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 19 Nov 2023 13:16:21 +0800
Subject: [PATCH 04/13] update
---
brainpy/_src/mixin.py | 13 -------------
1 file changed, 13 deletions(-)
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index 8ea8a5216..fe7c39940 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -519,19 +519,6 @@ def __subclasscheck__(self, subclass):
return all([issubclass(subclass, cls) for cls in self.__bases__])
-class UnionType2(MixIn):
- """Union type for multiple types.
-
- >>> import brainpy as bp
- >>>
- >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay])
- """
-
- @classmethod
- def __class_getitem__(cls, types: Union[type, Sequence[type]]) -> type:
- return _MetaUnionType('UnionType', types, {})
-
-
if sys.version_info.minor > 8:
class _JointGenericAlias(_UnionGenericAlias, _root=True):
def __subclasscheck__(self, subclass):
From c4f5b328dbd9876bd3a0c6af388f776e9cc2b341 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Mon, 20 Nov 2023 12:37:24 +0800
Subject: [PATCH 05/13] [math] add `brainpy.math.gpu_memory_preallocation()`
for controlling GPU memory preallocation
---
brainpy/_src/math/environment.py | 29 +++++++++++++++++++++++++----
brainpy/math/environment.py | 1 +
2 files changed, 26 insertions(+), 4 deletions(-)
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index eef0361fc..31c264e7d 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -702,13 +702,34 @@ def clear_buffer_memory(platform=None):
buf.delete()
-def disable_gpu_memory_preallocation():
- """Disable pre-allocating the GPU memory."""
+def disable_gpu_memory_preallocation(release_memory: bool = True):
+ """Disable pre-allocating the GPU memory.
+
+ This disables the preallocation behavior. JAX will instead allocate GPU memory as needed,
+ potentially decreasing the overall memory usage. However, this behavior is more prone to
+ GPU memory fragmentation, meaning a JAX program that uses most of the available GPU memory
+ may OOM with preallocation disabled.
+
+ Args:
+    release_memory: bool. Whether to release memory back to the system during computation.
+ """
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
- os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
+ if release_memory:
+ os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'
def enable_gpu_memory_preallocation():
"""Disable pre-allocating the GPU memory."""
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true'
- os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR')
+ os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR', None)
+
+
+def gpu_memory_preallocation(percent: float):
+ """GPU memory allocation.
+
+ If preallocation is enabled, this makes JAX preallocate ``percent`` of the total GPU memory,
+ instead of the default 75%. Lowering the amount preallocated can fix OOMs that occur when the JAX program starts.
+ """
+  assert 0. <= percent < 1., f'GPU memory preallocation fraction must be in [0., 1.), but got {percent}.'
+ os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = str(percent)
+
diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py
index a283cc921..d654a0217 100644
--- a/brainpy/math/environment.py
+++ b/brainpy/math/environment.py
@@ -30,6 +30,7 @@
clear_buffer_memory as clear_buffer_memory,
enable_gpu_memory_preallocation as enable_gpu_memory_preallocation,
disable_gpu_memory_preallocation as disable_gpu_memory_preallocation,
+ gpu_memory_preallocation as gpu_memory_preallocation,
ditype as ditype,
dftype as dftype,
)
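
With this patch the GPU allocator has three knobs side by side; each must run before the first JAX computation initializes the backend, since they only set environment variables. A minimal sketch (the 0.4 fraction is an arbitrary illustration):

    import brainpy.math as bm

    # preallocate 40% of the total GPU memory instead of the default 75%
    bm.gpu_memory_preallocation(0.4)

    # or: allocate on demand; release_memory=True additionally returns
    # freed blocks to the system at some allocation-speed cost
    # bm.disable_gpu_memory_preallocation(release_memory=True)

    # or: restore the default preallocation behavior
    # bm.enable_gpu_memory_preallocation()

    import jax.numpy as jnp
    x = jnp.ones((1024, 1024))  # the backend is initialized here
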
From ed4ce5fd5b44e50afbfd648f4321329385940c1f Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 26 Nov 2023 10:13:12 +0800
Subject: [PATCH 06/13] [math] `clear_buffer_memory` supports clearing both
 arrays and the compilation cache
---
brainpy/_src/math/environment.py | 19 ++++++++++++++++---
1 file changed, 16 insertions(+), 3 deletions(-)
diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py
index 31c264e7d..b7a17bb9e 100644
--- a/brainpy/_src/math/environment.py
+++ b/brainpy/_src/math/environment.py
@@ -9,6 +9,7 @@
import warnings
from typing import Any, Callable, TypeVar, cast
+import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge
@@ -682,7 +683,11 @@ def set_host_device_count(n):
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(n)] + xla_flags)
-def clear_buffer_memory(platform=None):
+def clear_buffer_memory(
+ platform: str = None,
+ array: bool = True,
+ compilation: bool = False
+):
"""Clear all on-device buffers.
This function will be very useful when you call models in a Python loop,
@@ -697,9 +702,17 @@ def clear_buffer_memory(platform=None):
----------
platform: str
The device to clear its memory.
+  array: bool
+    Whether to clear all buffered arrays.
+  compilation: bool
+    Whether to clear the compilation cache.
+
"""
- for buf in xla_bridge.get_backend(platform=platform).live_buffers():
- buf.delete()
+ if array:
+ for buf in xla_bridge.get_backend(platform=platform).live_buffers():
+ buf.delete()
+ if compilation:
+ jax.clear_caches()
def disable_gpu_memory_preallocation(release_memory: bool = True):
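
A typical use of the extended signature is a Python loop that rebuilds models, freeing the device arrays of each iteration while keeping compiled kernels cached. A sketch (model and run length are placeholders):

    import brainpy as bp
    import brainpy.math as bm

    for trial in range(10):
        net = bp.dyn.LifRef(10)                # fresh model every trial
        runner = bp.DSRunner(net, progress_bar=False)
        runner.run(100.)
        del net, runner                        # drop live references first
        # delete this trial's device buffers; keep the compilation cache
        bm.clear_buffer_memory(array=True, compilation=False)
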
From 8a2beb8404cefaf37b084c8d5cd6c3204f800c4b Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 26 Nov 2023 10:40:24 +0800
Subject: [PATCH 07/13] [dyn] compatibility with the old version of the
 `.reset_state()` function
---
brainpy/_src/dynsys.py | 52 +++++++++++++++++++++++-------------------
1 file changed, 29 insertions(+), 23 deletions(-)
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 00120a666..10d2de792 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -2,8 +2,8 @@
import collections
import inspect
-import warnings
import numbers
+import warnings
from typing import Union, Dict, Callable, Sequence, Optional, Any
import numpy as np
@@ -13,7 +13,7 @@
from brainpy._src.deprecations import _update_deprecate_msg
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool
-from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError
+from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
__all__ = [
@@ -27,9 +27,9 @@
'Dynamic', 'Projection',
]
-
IonChaDyn = None
SLICE_VARS = 'slice_vars'
+the_top_layer_reset_state = True
def not_implemented(fun):
@@ -138,16 +138,12 @@ def update(self, *args, **kwargs):
"""
raise NotImplementedError('Must implement "update" function by subclass self.')
- def reset(self, *args, include_self: bool = False, **kwargs):
+ def reset(self, *args, **kwargs):
"""Reset function which reset the whole variables in the model (including its children models).
``reset()`` function is a collective behavior which resets all states in this model.
See https://brainpy.readthedocs.io/en/latest/tutorial_toolbox/state_resetting.html for details.
-
- Args::
- include_self: bool. Reset states including the node self. Please turn on this if the node has
- implemented its ".reset_state()" function.
"""
from brainpy._src.helpers import reset_state
reset_state(self, *args, **kwargs)
@@ -162,19 +158,6 @@ def reset_state(self, *args, **kwargs):
"""
pass
- # raise APIChangedError(
- # '''
- # From version >= 2.4.6, the policy of ``.reset_state()`` has been changed.
- #
- # 1. If you are resetting all states in a network by calling "net.reset_state()", please use
- # "bp.reset_state(net)" function. ".reset_state()" only defines the resetting of local states
- # in a local node (excluded its children nodes).
- #
- # 2. If you does not customize "reset_state()" function for a local node, please implement it in your subclass.
- #
- # '''
- # )
-
def clear_input(self, *args, **kwargs):
"""Clear the input at the current time step."""
pass
@@ -344,14 +327,37 @@ def _compatible_update(self, *args, **kwargs):
return ret
return update_fun(*args, **kwargs)
+ def _compatible_reset_state(self, *args, **kwargs):
+ global the_top_layer_reset_state
+ the_top_layer_reset_state = False
+ try:
+ self.reset(*args, **kwargs)
+ finally:
+ the_top_layer_reset_state = True
+ warnings.warn(
+ '''
+ From version >= 2.4.6, the policy of ``.reset_state()`` has been changed. See https://brainpy.tech/docs/tutorial_toolbox/state_saving_and_loading.html for details.
+
+        1. If you are resetting all states in a network by calling "net.reset_state(*args, **kwargs)", please use
+        the "bp.reset_state(net, *args, **kwargs)" function or "net.reset(*args, **kwargs)".
+        ".reset_state()" only defines the resetting of local states in a local node (excluding its children nodes).
+
+        2. If you have not customized the "reset_state()" function for a local node, please implement it in your subclass.
+
+ ''',
+ DeprecationWarning
+ )
+
def _get_update_fun(self):
return object.__getattribute__(self, 'update')
def __getattribute__(self, item):
if item == 'update':
return self._compatible_update # update function compatible with previous ``update()`` function
- else:
- return super().__getattribute__(item)
+ if item == 'reset_state':
+ if the_top_layer_reset_state:
+ return self._compatible_reset_state # reset_state function compatible with previous ``reset_state()`` function
+ return super().__getattribute__(item)
def __repr__(self):
return f'{self.name}(mode={self.mode})'
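
With the shim in place, the three spellings below reset a whole network equivalently, the last one emitting the deprecation warning before delegating to `reset()`. A sketch (a single `LifRef` stands in for a full network):

    import brainpy as bp

    net = bp.dyn.LifRef(10)

    bp.reset_state(net)   # new collective API
    net.reset()           # equivalent collective reset
    net.reset_state()     # old spelling: warns, then routes through net.reset()
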
From 46bb987291a330450967e3d20acbc42abe1d1a78 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sun, 26 Nov 2023 10:41:30 +0800
Subject: [PATCH 08/13] [setup] update installation info
---
setup.py | 13 +++++++++++--
1 file changed, 11 insertions(+), 2 deletions(-)
diff --git a/setup.py b/setup.py
index 69c33cdfe..f867e3078 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,6 @@
# installation packages
packages = find_packages(exclude=['lib*', 'docs', 'tests'])
-
# setup
setup(
name='brainpy',
@@ -51,13 +50,23 @@
author_email='chao.brain@qq.com',
packages=packages,
python_requires='>=3.8',
- install_requires=['numpy>=1.15', 'jax', 'tqdm', 'msgpack', 'numba'],
+ install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'msgpack', 'numba'],
url='https://github.com/brainpy/BrainPy',
project_urls={
"Bug Tracker": "https://github.com/brainpy/BrainPy/issues",
"Documentation": "https://brainpy.readthedocs.io/",
"Source Code": "https://github.com/brainpy/BrainPy",
},
+ dependency_links=[
+ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html',
+ ],
+ extras_require={
+ 'cpu': ['jaxlib>=0.4.13', 'brainpylib'],
+ 'cuda': ['jax[cuda]', 'brainpylib-cu11x'],
+ 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'],
+ 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'],
+ 'tpu': ['jax[tpu]'],
+ },
keywords=('computational neuroscience, '
'brain-inspired computation, '
'dynamical systems, '
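
With these extras, platform-specific installation reduces to a single command, mirroring the extras defined above (the `-f` index is the `dependency_links` entry):

    > pip install brainpy[cpu]
    > pip install brainpy[cuda12] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
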
From 6dac69e8f647b98fa3b1cc39023221d7da1064fd Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 28 Nov 2023 14:23:19 +0800
Subject: [PATCH 09/13] [install] upgrade dependency
---
brainpy/_src/dependency_check.py | 34 ++++++++++++++++----------------
brainpy/_src/tools/install.py | 14 +++----------
2 files changed, 20 insertions(+), 28 deletions(-)
diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index 33456c02f..ebf6f9404 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -1,13 +1,13 @@
+import os
+import sys
from jax.lib import xla_client
-
__all__ = [
'import_taichi',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]
-
_minimal_brainpylib_version = '0.1.10'
_minimal_taichi_version = (1, 7, 0)
@@ -15,24 +15,27 @@
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None
+taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
+ f'Currently you can install taichi>={_minimal_taichi_version} through:\n\n'
+ '> pip install taichi -U')
+os.environ["TI_LOG_LEVEL"] = "error"
+
def import_taichi():
global taichi
if taichi is None:
- try:
- import taichi as taichi # noqa
- except ModuleNotFoundError:
- raise ModuleNotFoundError(
- 'Taichi is needed. Please install taichi through:\n\n'
- '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
- )
+ with open(os.devnull, 'w') as devnull:
+ old_stdout = sys.stdout
+ sys.stdout = devnull
+ try:
+ import taichi as taichi # noqa
+ except ModuleNotFoundError:
+ raise ModuleNotFoundError(taichi_install_info)
+ finally:
+ sys.stdout = old_stdout
if taichi.__version__ < _minimal_taichi_version:
- raise RuntimeError(
- f'We need taichi>={_minimal_taichi_version}. '
- f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n'
- '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
- )
+ raise RuntimeError(taichi_install_info)
return taichi
@@ -82,6 +85,3 @@ def import_brainpylib_gpu_ops():
'See https://brainpy.readthedocs.io for installation instructions.')
return brainpylib_gpu_ops
-
-
-
diff --git a/brainpy/_src/tools/install.py b/brainpy/_src/tools/install.py
index aadf0f5c0..4e4a537a9 100644
--- a/brainpy/_src/tools/install.py
+++ b/brainpy/_src/tools/install.py
@@ -8,19 +8,11 @@
BrainPy needs jaxlib, please install it.
-1. If you are using Windows system, install jaxlib through
+1. If you are using BrainPy on a CPU platform, please install jaxlib through
- >>> pip install jaxlib -f https://whls.blob.core.windows.net/unstable/index.html
+ >>> pip install jaxlib
-2. If you are using macOS platform, install jaxlib through
-
- >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
-
-3. If you are using Linux platform, install jaxlib through
-
- >>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_releases.html
-
-4. If you are using Linux + CUDA platform, install jaxlib through
+2. If you are using a Linux + CUDA platform, install jaxlib through
>>> pip install jaxlib -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
From 6c2c9bb3ce995a427535e48835b20d597d1ff19d Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 2 Dec 2023 14:42:07 +0800
Subject: [PATCH 10/13] updates
---
README.md | 3 +-
brainpy/__init__.py | 4 +-
brainpy/_src/dependency_check.py | 8 +-
brainpy/_src/dyn/projections/plasticity.py | 198 ++++++++++++++++++++-
brainpy/_src/measure/lfp.py | 2 +-
5 files changed, 203 insertions(+), 12 deletions(-)
diff --git a/README.md b/README.md
index 716dbd900..9c74b82d1 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,7 @@
-BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
+BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Taichi](https://github.com/taichi-dev/taichi), [Numba](https://github.com/numba/numba), and others). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
- **Source**: https://github.com/brainpy/BrainPy
@@ -77,6 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
+- [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling)
- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index 371ed6b27..1342eb9a0 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-__version__ = "2.4.6"
+__version__ = "2.4.6.post2"
# fundamental supporting modules
from brainpy import errors, check, tools
@@ -75,7 +75,7 @@
)
NeuGroup = NeuGroupNS = dyn.NeuDyn
-# shared parameters
+# common tools
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
save_state as save_state,
diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index ebf6f9404..e8492f826 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -15,9 +15,9 @@
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None
-taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
- f'Currently you can install taichi>={_minimal_taichi_version} through:\n\n'
- '> pip install taichi -U')
+taichi_install_info = (f'We need taichi=={".".join(map(str, _minimal_taichi_version))}. '
+                       'Currently you can install it through:\n\n'
+                       '> pip install taichi==1.7.0 -U')
os.environ["TI_LOG_LEVEL"] = "error"
@@ -34,7 +34,7 @@ def import_taichi():
finally:
sys.stdout = old_stdout
- if taichi.__version__ < _minimal_taichi_version:
+ if taichi.__version__ != _minimal_taichi_version:
raise RuntimeError(taichi_install_info)
return taichi
diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py
index 3ee6f4fef..3fb3c1232 100644
--- a/brainpy/_src/dyn/projections/plasticity.py
+++ b/brainpy/_src/dyn/projections/plasticity.py
@@ -49,8 +49,8 @@ class STDP_Song2000(Projection):
\begin{aligned}
\frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
- \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
- \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
+ \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
+ \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
\end{aligned}
where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
@@ -64,8 +64,8 @@ class STDP_Song2000(Projection):
class STDPNet(bp.DynamicalSystem):
def __init__(self, num_pre, num_post):
super().__init__()
- self.pre = bp.dyn.LifRef(num_pre, name='neu1')
- self.post = bp.dyn.LifRef(num_post, name='neu2')
+ self.pre = bp.dyn.LifRef(num_pre)
+ self.post = bp.dyn.LifRef(num_post)
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
@@ -219,3 +219,193 @@ def update(self):
return current
+# class PairedSTDP(Projection):
+# r"""Paired spike-time-dependent plasticity model.
+#
+# This model filters the synaptic currents according to the variables: :math:`w`.
+#
+# .. math::
+#
+# I_{syn}^+(t) = I_{syn}^-(t) * w
+#
+# where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
+# and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse,
+# the conductance of the synapse will increase by :math:`w`.
+#
+# The dynamics of :math:`w` is governed by the following equation:
+#
+# .. math::
+#
+# \begin{aligned}
+# \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
+# \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
+# \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
+# \end{aligned}
+#
+# where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
+# of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
+#
+# Here is an example of the usage of this class::
+#
+# import brainpy as bp
+# import brainpy.math as bm
+#
+# class STDPNet(bp.DynamicalSystem):
+# def __init__(self, num_pre, num_post):
+# super().__init__()
+# self.pre = bp.dyn.LifRef(num_pre)
+# self.post = bp.dyn.LifRef(num_post)
+#         self.syn = bp.dyn.PairedSTDP(
+# pre=self.pre,
+# delay=1.,
+# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
+# weight=bp.init.Uniform(max_val=0.1)),
+# syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
+# out=bp.dyn.COBA.desc(E=0.),
+# post=self.post,
+# tau_s=16.8,
+# tau_t=33.7,
+# A1=0.96,
+# A2=0.53,
+# )
+#
+# def update(self, I_pre, I_post):
+# self.syn()
+# self.pre(I_pre)
+# self.post(I_post)
+# conductance = self.syn.refs['syn'].g
+# Apre = self.syn.refs['pre_trace'].g
+# Apost = self.syn.refs['post_trace'].g
+# current = self.post.sum_inputs(self.post.V)
+# return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
+#
+# duration = 300.
+# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+# [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
+# I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+# [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
+#
+# net = STDPNet(1, 1)
+# def run(i, I_pre, I_post):
+# pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
+# return pre_spike, post_spike, g, Apre, Apost, current, W
+#
+# indices = bm.arange(0, duration, bm.dt)
+# pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])
+#
+# Args:
+# tau_s: float. The time constant of :math:`A_{pre}`.
+# tau_t: float. The time constant of :math:`A_{post}`.
+# A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
+# A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
+# W_max: float. The maximum weight.
+# W_min: float. The minimum weight.
+# pre: DynamicalSystem. The pre-synaptic neuron group.
+# delay: int, float. The pre spike delay length. (ms)
+# syn: DynamicalSystem. The synapse model.
+# comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers.
+# out: DynamicalSystem. The synaptic current output models.
+# post: DynamicalSystem. The post-synaptic neuron group.
+# out_label: str. The output label.
+# name: str. The model name.
+# """
+#
+# def __init__(
+# self,
+# pre: JointType[DynamicalSystem, SupportAutoDelay],
+# delay: Union[None, int, float],
+# syn: ParamDescriber[DynamicalSystem],
+# comm: JointType[DynamicalSystem, SupportSTDP],
+# out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
+# post: DynamicalSystem,
+# # synapse parameters
+# tau_s: float = 16.8,
+# tau_t: float = 33.7,
+#       A1: float = 0.96,
+#       A2: float = 0.53,
+#       mu: float = 0.53,
+# W_max: Optional[float] = None,
+# W_min: Optional[float] = None,
+# # others
+# out_label: Optional[str] = None,
+# name: Optional[str] = None,
+# mode: Optional[bm.Mode] = None,
+# ):
+# super().__init__(name=name, mode=mode)
+#
+# # synaptic models
+# check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+# check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
+# check.is_instance(syn, ParamDescriber[DynamicalSystem])
+# check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
+# check.is_instance(post, DynamicalSystem)
+# self.pre_num = pre.num
+# self.post_num = post.num
+# self.comm = comm
+# self._is_align_post = issubclass(syn.cls, AlignPost)
+#
+# # delay initialization
+# delay_cls = register_delay_by_return(pre)
+# delay_cls.register_entry(self.name, delay)
+#
+# # synapse and output initialization
+# if self._is_align_post:
+# syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post,
+# proj_name=self.name)
+# else:
+# syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre')
+# out_cls = out()
+# add_inp_fun(out_label, self.name, out_cls, post)
+#
+# # references
+# self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
+# self.refs['delay'] = delay_cls
+# self.refs['syn'] = syn_cls # invisible to ``self.node()``
+# self.refs['out'] = out_cls # invisible to ``self.node()``
+# self.refs['comm'] = comm
+#
+# # tracing pre-synaptic spikes using Exponential model
+# self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))
+#
+# # tracing post-synaptic spikes using Exponential model
+# self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))
+#
+# # synapse parameters
+# self.W_max = W_max
+# self.W_min = W_min
+# self.tau_s = tau_s
+# self.tau_t = tau_t
+# self.A1 = A1
+# self.A2 = A2
+#
+# def update(self):
+# # pre-synaptic spikes
+# pre_spike = self.refs['delay'].at(self.name) # spike
+# # pre-synaptic variables
+# if self._is_align_post:
+# # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance
+# x = pre_spike
+# else:
+# # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance
+# x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable
+#
+# # post spikes
+# if not hasattr(self.refs['post'], 'spike'):
+# raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
+# post_spike = self.refs['post'].spike
+#
+# # weight updates
+# Apost = self.refs['post_trace'].g
+# self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max)
+# Apre = self.refs['pre_trace'].g
+# self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max)
+#
+# # synaptic currents
+# current = self.comm(x)
+# if self._is_align_post:
+# self.refs['syn'].add_current(current) # synapse post current
+# else:
+# self.refs['out'].bind_cond(current) # align pre
+# return current
+
+
diff --git a/brainpy/_src/measure/lfp.py b/brainpy/_src/measure/lfp.py
index 0662be8d9..434666efb 100644
--- a/brainpy/_src/measure/lfp.py
+++ b/brainpy/_src/measure/lfp.py
@@ -10,7 +10,7 @@
]
-def unitary_LFP(times, spikes, spike_type='exc',
+def unitary_LFP(times, spikes, spike_type,
xmax=0.2, ymax=0.2, va=200., lambda_=0.2,
sig_i=2.1, sig_e=2.1 * 1.5, location='soma layer', seed=None):
"""A kernel-based method to calculate unitary local field potentials (uLFP)
From 670937e95c800f9f732f2ee709723b0e966835fd Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 2 Dec 2023 16:11:28 +0800
Subject: [PATCH 11/13] [math] add `brainpy.math.defjvp`, supporting the
 definition of JVP rules for Primitives with multiple results. See examples in
 `test_ad_support.py`
---
brainpy/_src/math/ad_support.py | 50 ++++++++
brainpy/_src/math/op_register/base.py | 4 +-
brainpy/_src/math/tests/test_ad_support.py | 136 +++++++++++++++++++++
brainpy/math/others.py | 4 +
4 files changed, 192 insertions(+), 2 deletions(-)
create mode 100644 brainpy/_src/math/ad_support.py
create mode 100644 brainpy/_src/math/tests/test_ad_support.py
diff --git a/brainpy/_src/math/ad_support.py b/brainpy/_src/math/ad_support.py
new file mode 100644
index 000000000..fb710a675
--- /dev/null
+++ b/brainpy/_src/math/ad_support.py
@@ -0,0 +1,50 @@
+import functools
+from functools import partial
+
+from jax import tree_util
+from jax.core import Primitive
+from jax.interpreters import ad
+from brainpy._src.math.op_register.base import XLACustomOp
+
+__all__ = [
+ 'defjvp',
+]
+
+
+def defjvp(primitive, *jvp_rules):
+ """Define JVP rule when the primitive
+
+ Args:
+ primitive: Primitive, XLACustomOp.
+ *jvp_rules: The JVP translation rule for each primal.
+
+ Returns:
+ The JVP gradients.
+ """
+ if isinstance(primitive, XLACustomOp):
+ primitive = primitive.primitive
+ assert isinstance(primitive, Primitive)
+ if primitive.multiple_results:
+ ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
+ else:
+ ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive)
+
+
+def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params):
+ assert primitive.multiple_results
+ val_out = tuple(primitive.bind(*primals, **params))
+ tree = tree_util.tree_structure(val_out)
+ tangents_out = []
+ for rule, t in zip(jvp_rules, tangents):
+ if rule is not None and type(t) is not ad.Zero:
+ r = tuple(rule(t, *primals, **params))
+ tangents_out.append(r)
+ assert tree_util.tree_structure(r) == tree
+ return val_out, functools.reduce(_add_tangents,
+ tangents_out,
+ tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out))
+
+
+def _add_tangents(xs, ys):
+ return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero))
+
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index 31aef70d6..6def88950 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -139,13 +139,13 @@ def __init__(
if transpose_translation is not None:
ad.primitive_transposes[self.primitive] = transpose_translation
- def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
+ def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs):
if outs is None:
outs = self.outs
assert outs is not None
outs = tuple([_transform_to_shapedarray(o) for o in outs])
ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array)
- return self.primitive.bind(*ins, outs=outs)
+ return self.primitive.bind(*ins, outs=outs, **kwargs)
def def_abstract_eval(self, fun):
"""Define the abstract evaluation function.
diff --git a/brainpy/_src/math/tests/test_ad_support.py b/brainpy/_src/math/tests/test_ad_support.py
new file mode 100644
index 000000000..66b8418b8
--- /dev/null
+++ b/brainpy/_src/math/tests/test_ad_support.py
@@ -0,0 +1,136 @@
+from typing import Tuple
+
+import jax
+import numba
+from jax import core
+from jax import numpy as jnp
+from jax.interpreters import ad
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
+ data = jnp.atleast_1d(bm.as_jax(data))
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ vector = bm.as_jax(vector)
+ if vector.dtype == jnp.bool_:
+ vector = bm.as_jax(vector, dtype=data.dtype)
+ outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
+ if transpose:
+ return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+ else:
+ return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+
+
+@numba.njit(fastmath=True)
+def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
+ res_val.fill(0)
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in range(vector.shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += values * v
+ else:
+ for row_i in range(vector.shape[0]):
+ v = vector[row_i]
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ res_val[col_indices[j]] += v * values[j]
+
+
+@numba.njit(fastmath=True, parallel=True, nogil=True)
+def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
+ res_val.fill(0)
+ # csr mat @ vec
+ if values.shape[0] == 1:
+ values = values[0]
+ for row_i in numba.prange(res_val.shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values * vector[col_indices[j]]
+ res_val[row_i] = r
+ else:
+ for row_i in numba.prange(res_val.shape[0]):
+ r = 0.
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+ r += values[j] * vector[col_indices[j]]
+ res_val[row_i] = r
+
+
+def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+ return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
+
+
+def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+ return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
+
+
+def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
+ if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+ raise ValueError("Cannot transpose with respect to sparse indices.")
+
+ ct = ct[0]
+ if ad.is_undefined_primal(vector):
+ ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+ return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+
+ else:
+ if type(ct) is ad.Zero:
+ ct_data = ad.Zero(data)
+ else:
+ if data.aval.shape[0] == 1: # scalar
+ ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+ ct_data = jnp.inner(ct, ct_data)
+ else: # heterogeneous values
+ row, col = bm.sparse.csr_to_coo(indices, indptr)
+ ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+ return ct_data, indices, indptr, vector
+
+
+prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
+bm.defjvp(prim_trans, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
+
+prim = bm.XLACustomOp(_csr_matvec_numba_imp)
+bm.defjvp(prim, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim.def_transpose_rule(_csrmv_cusparse_transpose)
+
+
+def sum_op(op):
+ def func(*args, **kwargs):
+ r = op(*args, **kwargs)
+ return r.sum()
+
+ return func
+
+
+def try_a_trial(transpose, shape):
+ rng = bm.random.RandomState()
+ conn = bp.conn.FixedProb(0.1)
+ indices, indptr = conn(*shape).require('pre2post')
+ indices = bm.as_jax(indices)
+ indptr = bm.as_jax(indptr)
+ heter_data = rng.random(indices.shape)
+ heter_data = bm.as_jax(heter_data)
+ vector = rng.random(shape[0] if transpose else shape[1])
+ vector = bm.as_jax(vector)
+
+ r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs, method='vector')), argnums=(0, 3))(
+ heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
+ heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+ print(r5)
+ print(r6)
+ assert bm.allclose(r5[0], r6[0])
+ assert bm.allclose(r5[1], r6[1][0])
+
+
+def test():
+ transposes = [True, False]
+ shapes = [(100, 200), (10, 1000), (2, 2000)]
+
+ for transpose in transposes:
+ for shape in shapes:
+ try_a_trial(transpose, shape)
diff --git a/brainpy/math/others.py b/brainpy/math/others.py
index 23d9b0816..d1108d1fa 100644
--- a/brainpy/math/others.py
+++ b/brainpy/math/others.py
@@ -9,3 +9,7 @@
from brainpy._src.math.object_transform.naming import (
clear_name_cache,
)
+
+from brainpy._src.math.ad_support import (
+ defjvp as defjvp,
+)
From f45e635f0ed35a0c885fc1a47b78caa902976d04 Mon Sep 17 00:00:00 2001
From: chaoming
Date: Sat, 2 Dec 2023 16:17:29 +0800
Subject: [PATCH 12/13] [math] add `brainpy.math.XLACustomOp.defjvp`
---
brainpy/_src/math/{ => op_register}/ad_support.py | 3 ---
brainpy/_src/math/op_register/base.py | 15 +++++++++++----
.../{ => op_register}/tests/test_ad_support.py | 4 ++--
brainpy/math/op_register.py | 3 +--
brainpy/math/others.py | 4 ----
5 files changed, 14 insertions(+), 15 deletions(-)
rename brainpy/_src/math/{ => op_register}/ad_support.py (91%)
rename brainpy/_src/math/{ => op_register}/tests/test_ad_support.py (97%)
diff --git a/brainpy/_src/math/ad_support.py b/brainpy/_src/math/op_register/ad_support.py
similarity index 91%
rename from brainpy/_src/math/ad_support.py
rename to brainpy/_src/math/op_register/ad_support.py
index fb710a675..0e50091f2 100644
--- a/brainpy/_src/math/ad_support.py
+++ b/brainpy/_src/math/op_register/ad_support.py
@@ -4,7 +4,6 @@
from jax import tree_util
from jax.core import Primitive
from jax.interpreters import ad
-from brainpy._src.math.op_register.base import XLACustomOp
__all__ = [
'defjvp',
@@ -21,8 +20,6 @@ def defjvp(primitive, *jvp_rules):
Returns:
The JVP gradients.
"""
- if isinstance(primitive, XLACustomOp):
- primitive = primitive.primitive
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive)
diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index 6def88950..cb05ece81 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -14,11 +14,10 @@
# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
- register_taichi_gpu_translation_rule,
- encode_md5,
- _preprocess_kernel_call_cpu,
- get_source_with_dependencies)
+ register_taichi_gpu_translation_rule,)
from .utils import register_general_batching
+from brainpy._src.math.op_register.ad_support import defjvp
+
__all__ = [
'XLACustomOp',
@@ -171,6 +170,14 @@ def def_jvp_rule(self, fun):
"""
ad.primitive_jvps[self.primitive] = fun
+ def defjvp(self, *jvp_rules):
+ """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results.
+
+ Args:
+ jvp_rules: The JVP rules.
+ """
+ defjvp(self.primitive, *jvp_rules)
+
def def_transpose_rule(self, fun):
"""Define the transpose rule.
diff --git a/brainpy/_src/math/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py
similarity index 97%
rename from brainpy/_src/math/tests/test_ad_support.py
rename to brainpy/_src/math/op_register/tests/test_ad_support.py
index 66b8418b8..547bbdc7c 100644
--- a/brainpy/_src/math/tests/test_ad_support.py
+++ b/brainpy/_src/math/op_register/tests/test_ad_support.py
@@ -90,11 +90,11 @@ def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, trans
prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
-bm.defjvp(prim_trans, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
prim = bm.XLACustomOp(_csr_matvec_numba_imp)
-bm.defjvp(prim, _csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
prim.def_transpose_rule(_csrmv_cusparse_transpose)
diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py
index b30ce4414..014a54e6f 100644
--- a/brainpy/math/op_register.py
+++ b/brainpy/math/op_register.py
@@ -6,8 +6,7 @@
compile_cpu_signature_with_numba,
)
-
from brainpy._src.math.op_register.base import XLACustomOp
-
+from brainpy._src.math.op_register.ad_support import defjvp
diff --git a/brainpy/math/others.py b/brainpy/math/others.py
index d1108d1fa..23d9b0816 100644
--- a/brainpy/math/others.py
+++ b/brainpy/math/others.py
@@ -9,7 +9,3 @@
from brainpy._src.math.object_transform.naming import (
clear_name_cache,
)
-
-from brainpy._src.math.ad_support import (
- defjvp as defjvp,
-)
From ffeb9cbdb6fac32e9968eb2631145690cfe2bb6a Mon Sep 17 00:00:00 2001
From: chaoming
Date: Mon, 4 Dec 2023 14:08:34 +0800
Subject: [PATCH 13/13] [doc] upgrade `brainpy.math.defjvp` docstring
---
brainpy/_src/math/op_register/ad_support.py | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py
index 0e50091f2..f7bf9554a 100644
--- a/brainpy/_src/math/op_register/ad_support.py
+++ b/brainpy/_src/math/op_register/ad_support.py
@@ -11,14 +11,19 @@
def defjvp(primitive, *jvp_rules):
- """Define JVP rule when the primitive
+ """Define JVP rules for any JAX primitive.
+
+ This function is similar to ``jax.interpreters.ad.defjvp``.
+  However, the JAX version only supports primitives with ``multiple_results=False``,
+  while ``brainpy.math.defjvp`` makes it possible to define an independent JVP rule for
+  each input parameter, regardless of whether ``multiple_results`` is ``True`` or ``False``.
+
+ For examples, please see ``test_ad_support.py``.
+
Args:
primitive: Primitive, XLACustomOp.
*jvp_rules: The JVP translation rule for each primal.
-
- Returns:
- The JVP gradients.
"""
assert isinstance(primitive, Primitive)
if primitive.multiple_results:
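
Putting the series together, a minimal sketch of what the upgraded docstring describes: a custom operator with two outputs and one JVP rule per input. The numba kernel, the two-output setup, and the rule are illustrative assumptions, not library code:

    import numba
    import jax
    from jax import core
    from jax import numpy as jnp

    import brainpy.math as bm

    @numba.njit
    def double_and_square(x, out1, out2):
        # writes the two outputs (2*x, x**2) into preallocated buffers
        for i in range(x.shape[0]):
            out1[i] = 2. * x[i]
            out2[i] = x[i] * x[i]

    prim = bm.XLACustomOp(double_and_square)

    def jvp_x(x_dot, x, **params):
        # one tangent per output, matching the primal output structure
        return 2. * x_dot, 2. * x * x_dot

    prim.defjvp(jvp_x)  # equivalently: bm.defjvp(prim.primitive, jvp_x)

    def call(x):
        outs = [core.ShapedArray(x.shape, x.dtype),
                core.ShapedArray(x.shape, x.dtype)]
        return prim(x, outs=outs)

    x = jnp.ones(4)
    primals, tangents = jax.jvp(call, (x,), (jnp.ones(4),))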