diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 000000000..d547100d3
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,59 @@
+name: Docker
+
+on:
+ release:
+ types: [published]
+ pull_request:
+ paths:
+ - docker/**
+ - .github/workflows/docker.yml
+
+
+jobs:
+ docker-build-push:
+ if: |
+ github.repository_owner == 'brainpy' ||
+ github.event_name != 'release'
+ runs-on: ubuntu-22.04
+ strategy:
+ matrix:
+ include:
+ - context: "docker/"
+ base: "brainpy/brainpy"
+ env:
+ TARGET_PLATFORMS: linux/amd64
+ REGISTRY: ghcr.io
+ IMAGE_NAME: ${{ github.repository }}
+ DOCKER_TAG_NAME: |
+ ${{
+ (github.event_name == 'release' && github.event.release.tag_name) ||
+ 'pull-request-test'
+ }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Login to DockerHub
+ if: github.event_name != 'pull_request'
+ uses: docker/login-action@v2
+ with:
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+ - name: Docker Build & Push (version tag)
+ uses: docker/build-push-action@v4
+ with:
+ context: ${{ matrix.context }}
+ tags: ${{ matrix.base }}:${{ env.DOCKER_TAG_NAME }}
+ push: ${{ github.event_name != 'pull_request' }}
+ platforms: ${{ env.TARGET_PLATFORMS }}
+
+ - name: Docker Build & Push (latest tag)
+ if: |
+ (github.event_name == 'release' && ! github.event.release.prerelease)
+ uses: docker/build-push-action@v4
+ with:
+ context: ${{ matrix.context }}
+ tags: ${{ matrix.base }}:latest
+ push: ${{ github.event_name != 'pull_request' }}
+ platforms: ${{ env.TARGET_PLATFORMS }}
\ No newline at end of file
diff --git a/.gitignore b/.gitignore
index dec4fa91d..29424003d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -225,3 +225,4 @@ cython_debug/
/docs/tutorial_advanced/data/
/my_tests/
/examples/dynamics_simulation/Joglekar_2018_data/
+/docs/apis/deprecated/generated/
diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
new file mode 100644
index 000000000..caf968c4a
--- /dev/null
+++ b/ACKNOWLEDGMENTS.md
@@ -0,0 +1,13 @@
+# Acknowledgments
+
+The development of BrainPy is being or has been supported by many organizations, programs, and individuals since 2020.
+The following list of support received is therefore necessarily incomplete.
+
+
+This project has received funding from Science and Technology Innovation 2030 (China Brain Project):
+
+- Brain Science and Brain-inspired Intelligence Project (No. 2021ZD0200204).
+
+Additionally, BrainPy gratefully acknowledges the support and funding received from:
+
+- Beijing Academy of Artificial Intelligence.
diff --git a/README.md b/README.md
index 8d98ceaab..263d74568 100644
--- a/README.md
+++ b/README.md
@@ -34,22 +34,37 @@ $ pip install brainpy brainpylib -U
For detailed installation instructions, please refer to the documentation: [Quickstart/Installation](https://brainpy.readthedocs.io/en/latest/quickstart/installation.html)
+### Using BrainPy with docker
+
+We provide a docker image for BrainPy. You can use the following command to pull the image:
+```bash
+$ docker pull brainpy/brainpy:latest
+```
+
+Then, you can run the image with the following command:
+```bash
+$ docker run -it --platform linux/amd64 brainpy/brainpy:latest
+```
+
+### Using BrainPy with Binder
+
+We provide a Binder environment for BrainPy. You can use the following button to launch the environment:
+
+[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main)
+
## Ecosystem
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
-## Citing and Funding
-
-If you are using ``brainpy``, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).
+## Citing
BrainPy is developed by a team in Neural Information Processing Lab at Peking University, China.
Our team is committed to the long-term maintenance and development of the project.
-Moreover, the development of BrainPy is being or has been supported by Science and Technology
-Innovation 2030 - Brain Science and Brain-inspired Intelligence Project (China Brain Project),
-and Beijing Academy of Artificial Intelligence.
+If you are using ``brainpy``, please consider citing [the corresponding papers](https://brainpy.readthedocs.io/en/latest/tutorial_FAQs/citing_and_publication.html).
+
## Ongoing development plans
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index c31989a2a..97f5aa304 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-__version__ = "2.4.4.post3"
+__version__ = "2.4.4.post4"
# fundamental supporting modules
from brainpy import errors, check, tools
diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py
index 9b9e7bf01..8ffdc05e6 100644
--- a/brainpy/_src/delay.py
+++ b/brainpy/_src/delay.py
@@ -327,7 +327,7 @@ def retrieve(self, delay_step, *indices):
if self.method == ROTATE_UPDATE:
i = share.load('i')
- delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length)
+ delay_idx = bm.as_jax((delay_step - i - 1) % self.max_length, dtype=jnp.int32)
delay_idx = jax.lax.stop_gradient(delay_idx)
elif self.method == CONCAT_UPDATE:
@@ -358,7 +358,7 @@ def update(
# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
i = share.load('i')
- idx = bm.as_jax((-i - 1) % self.max_length)
+ idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
self.data[idx] = latest_value
# update the delay data at the first position
diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py
index 3bdc3a31c..5fdee8d99 100644
--- a/brainpy/_src/dnn/linear.py
+++ b/brainpy/_src/dnn/linear.py
@@ -10,12 +10,12 @@
from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
-from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
from brainpy.check import is_initializer
from brainpy.errors import MathError
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding
from brainpy._src.dnn.base import Layer
+from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
__all__ = [
'Dense', 'Linear',
@@ -29,14 +29,14 @@
]
-class Dense(Layer):
+class Dense(Layer, SupportOnline, SupportOffline, SupportSTDP):
r"""A linear transformation applied over the last dimension of the input.
Mathematically, this node can be defined as:
.. math::
- y = x \cdot W + b
+ y = x \cdot weight + b
Parameters
----------
@@ -44,7 +44,7 @@ class Dense(Layer):
The number of the input feature. A positive integer.
num_out: int
The number of the output features. A positive integer.
- W_initializer: optional, Initializer
+ W_initializer: optional, Initializer
The weight initialization.
b_initializer: optional, Initializer
The bias initialization.
@@ -52,12 +52,6 @@ class Dense(Layer):
Enable training this node or not. (default True)
"""
- online_fit_by: Optional[OnlineAlgorithm]
- '''Online fitting method.'''
-
- offline_fit_by: Optional[OfflineAlgorithm]
- '''Offline fitting method.'''
-
def __init__(
self,
num_in: int,
@@ -80,13 +74,13 @@ def __init__(
f'a positive integer. Received: num_out={num_out}')
# weight initializer
- self.weight_initializer = W_initializer
+ self.W_initializer = W_initializer
self.bias_initializer = b_initializer
is_initializer(W_initializer, 'weight_initializer')
is_initializer(b_initializer, 'bias_initializer', allow_none=True)
# parameter initialization
- W = parameter(self.weight_initializer, (num_in, self.num_out))
+ W = parameter(self.W_initializer, (num_in, self.num_out))
b = parameter(self.bias_initializer, (self.num_out,))
if isinstance(self.mode, bm.TrainingMode):
W = bm.TrainVar(W)
@@ -95,8 +89,8 @@ def __init__(
self.b = b
# fitting parameters
- self.online_fit_by = None
- self.offline_fit_by = None
+ self.online_fit_by = None # support online training
+ self.offline_fit_by = None # support offline training
self.fit_record = dict()
def __repr__(self):
@@ -204,6 +198,20 @@ def offline_fit(self,
self.W.value = Wff
self.b.value = bias[0]
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.W, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ if self.W.shape != dW.shape:
+ raise ValueError(f'The shape of delta_weight {dW.shape} '
+ f'should be the same as the shape of weight {self.W.shape}.')
+ if not isinstance(self.W, bm.Variable):
+ self.tracing_variable('W', self.W, self.W.shape)
+ self.W += dW
+ if constraints is not None:
+ self.W.value = constraints(self.W)
+
Linear = Dense
@@ -219,7 +227,7 @@ def update(self, x):
return x
-class AllToAll(Layer):
+class AllToAll(Layer, SupportSTDP):
"""Synaptic matrix multiplication with All2All connections.
Args:
@@ -281,8 +289,23 @@ def update(self, pre_val):
post_val = pre_val @ self.weight
return post_val
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ if self.weight.shape != dW.shape:
+ raise ValueError(f'The shape of delta_weight {dW.shape} '
+ f'should be the same as the shape of weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ self.weight += dW
+ if constraints is not None:
+ self.weight.value = constraints(self.weight)
+
+
-class OneToOne(Layer):
+class OneToOne(Layer, SupportSTDP):
"""Synaptic matrix multiplication with One2One connection.
Args:
@@ -315,8 +338,23 @@ def __init__(
def update(self, pre_val):
return pre_val * self.weight
-
-class MaskedLinear(Layer):
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ dW = dW.sum(axis=0)
+ if self.weight.shape != dW.shape:
+ raise ValueError(f'The shape of delta_weight {dW.shape} '
+ f'should be the same as the shape of weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ self.weight += dW
+ if constraints is not None:
+ self.weight.value = constraints(self.weight)
+
+
+class MaskedLinear(Layer, SupportSTDP):
r"""Synaptic matrix multiplication with masked dense computation.
It performs the computation of:
@@ -369,8 +407,23 @@ def __init__(
def update(self, x):
return x @ self.mask_fun(self.weight * self.mask)
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ if self.weight.shape != dW.shape:
+ raise ValueError(f'The shape of delta_weight {dW.shape} '
+ f'should be the same as the shape of weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+
+ self.weight += dW
+ if constraints is not None:
+ self.weight.value = constraints(self.weight)
+
-class CSRLinear(Layer):
+class CSRLinear(Layer, SupportSTDP):
r"""Synaptic matrix multiplication with CSR sparse computation.
It performs the computation of:
@@ -438,6 +491,22 @@ def _batch_csrmv(self, x):
transpose=self.transpose,
method=self.method)
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+ sparse_dW = dW[pre_ids, post_ids]
+ if self.weight.shape != sparse_dW.shape:
+ raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+ f'should be the same as the shape of sparse weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ self.weight += sparse_dW
+ if constraints is not None:
+ self.weight.value = constraints(self.weight)
+
class CSCLinear(Layer):
r"""Synaptic matrix multiplication with CSC sparse computation.
@@ -474,7 +543,7 @@ def __init__(
self.sharding = sharding
-class EventCSRLinear(Layer):
+class EventCSRLinear(Layer, SupportSTDP):
r"""Synaptic matrix multiplication with event CSR sparse computation.
It performs the computation of:
@@ -538,6 +607,22 @@ def _batch_csrmv(self, x):
shape=(self.conn.pre_num, self.conn.post_num),
transpose=self.transpose)
+ def update_STDP(self, dW, constraints=None):
+ if isinstance(self.weight, float):
+ raise ValueError(f'Cannot update the weight of a constant node.')
+ if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
+ raise ValueError(f'"delta_weight" must be an array, but got {type(dW)}')
+ pre_ids, post_ids = bm.sparse.csr_to_coo(self.indices, self.indptr)
+ sparse_dW = dW[pre_ids, post_ids]
+ if self.weight.shape != sparse_dW.shape:
+ raise ValueError(f'The shape of sparse delta_weight {sparse_dW.shape} '
+ f'should be the same as the shape of sparse weight {self.weight.shape}.')
+ if not isinstance(self.weight, bm.Variable):
+ self.tracing_variable('weight', self.weight, self.weight.shape)
+ self.weight += sparse_dW
+ if constraints is not None:
+ self.weight.value = constraints(self.weight)
+
class BcsrMM(Layer):
r"""Synaptic matrix multiplication with BCSR sparse computation.
diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
index e318eee4b..e18ac2a82 100644
--- a/brainpy/_src/dyn/base.py
+++ b/brainpy/_src/dyn/base.py
@@ -1,19 +1,19 @@
# -*- coding: utf-8 -*-
from brainpy._src.dynsys import Dynamic
-from brainpy._src.mixin import AutoDelaySupp, ParamDesc
+from brainpy._src.mixin import SupportAutoDelay, ParamDesc
__all__ = [
'NeuDyn', 'SynDyn', 'IonChaDyn',
]
-class NeuDyn(Dynamic, AutoDelaySupp):
+class NeuDyn(Dynamic, SupportAutoDelay):
"""Neuronal Dynamics."""
pass
-class SynDyn(Dynamic, AutoDelaySupp, ParamDesc):
+class SynDyn(Dynamic, SupportAutoDelay, ParamDesc):
"""Synaptic Dynamics."""
pass
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index 2dfa2dd14..23b907286 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -4,7 +4,7 @@
from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo,
- AutoDelaySupp, BindCondData, AlignPost)
+ SupportAutoDelay, BindCondData, AlignPost)
__all__ = [
'VanillaProj',
@@ -297,7 +297,7 @@ def update(self, inp):
def __init__(
self,
- pre: JointType[DynamicalSystem, AutoDelaySupp],
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]],
@@ -310,7 +310,7 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]])
check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
@@ -507,7 +507,7 @@ def update(self, inp):
def __init__(
self,
- pre: JointType[DynamicalSystem, AutoDelaySupp],
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
syn: JointType[DynamicalSystem, AlignPost],
@@ -520,7 +520,7 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(syn, JointType[DynamicalSystem, AlignPost])
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
@@ -631,7 +631,7 @@ def update(self, inp):
def __init__(
self,
pre: DynamicalSystem,
- syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]],
+ syn: ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]],
delay: Union[None, int, float],
comm: DynamicalSystem,
out: JointType[DynamicalSystem, BindCondData],
@@ -644,7 +644,7 @@ def __init__(
# synaptic models
check.is_instance(pre, DynamicalSystem)
- check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]])
+ check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, SupportAutoDelay]])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
check.is_instance(post, DynamicalSystem)
@@ -654,7 +654,7 @@ def __init__(
self._syn_id = f'{syn.identifier} // Delay'
if not pre.has_aft_update(self._syn_id):
# "syn_cls" needs an instance of "ProjAutoDelay"
- syn_cls: AutoDelaySupp = syn()
+ syn_cls: SupportAutoDelay = syn()
delay_cls = init_delay_by_return(syn_cls.return_info())
# add to "after_updates"
pre.add_aft_update(self._syn_id, _AlignPre(syn_cls, delay_cls))
@@ -755,7 +755,7 @@ def update(self, inp):
def __init__(
self,
- pre: JointType[DynamicalSystem, AutoDelaySupp],
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: ParamDescInit[DynamicalSystem],
comm: DynamicalSystem,
@@ -768,7 +768,7 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, ParamDescInit[DynamicalSystem])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
@@ -884,7 +884,7 @@ def update(self, inp):
def __init__(
self,
pre: DynamicalSystem,
- syn: JointType[DynamicalSystem, AutoDelaySupp],
+ syn: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
comm: DynamicalSystem,
out: JointType[DynamicalSystem, BindCondData],
@@ -897,7 +897,7 @@ def __init__(
# synaptic models
check.is_instance(pre, DynamicalSystem)
- check.is_instance(syn, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(syn, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
check.is_instance(post, DynamicalSystem)
@@ -1002,7 +1002,7 @@ def update(self, inp):
def __init__(
self,
- pre: JointType[DynamicalSystem, AutoDelaySupp],
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: DynamicalSystem,
comm: DynamicalSystem,
@@ -1015,7 +1015,7 @@ def __init__(
super().__init__(name=name, mode=mode)
# synaptic models
- check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, DynamicalSystem)
check.is_instance(comm, DynamicalSystem)
check.is_instance(out, JointType[DynamicalSystem, BindCondData])
@@ -1052,4 +1052,4 @@ def update(self):
spk = self.refs['delay'].at(self.name)
g = self.comm(self.syn(spk))
self.refs['out'].bind_cond(g)
- return g
+ return g
\ No newline at end of file
diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py
new file mode 100644
index 000000000..a85f6e1fc
--- /dev/null
+++ b/brainpy/_src/dyn/projections/plasticity.py
@@ -0,0 +1,238 @@
+from typing import Optional, Callable, Union
+
+from brainpy import math as bm, check
+from brainpy._src.delay import DelayAccess, delay_identifier, init_delay_by_return
+from brainpy._src.dyn.synapses.abstract_models import Expon
+from brainpy._src.dynsys import DynamicalSystem, Projection
+from brainpy._src.initialize import parameter
+from brainpy._src.mixin import (JointType, ParamDescInit, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP)
+from brainpy.types import ArrayType
+from .aligns import _AlignPost, _AlignPreMg, _get_return
+
+__all__ = [
+ 'STDP_Song2000',
+]
+
+
+class STDP_Song2000(Projection):
+ r"""Synaptic output with spike-time-dependent plasticity.
+
+ This model filters the synaptic currents according to the variables: :math:`w`.
+
+ .. math::
+
+ I_{syn}^+(t) = I_{syn}^-(t) * w
+
+ where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
+ and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse,
+ the conductance of the synapse will increase by w.
+
+ The dynamics of :math:`w` is governed by the following equation:
+
+ .. math::
+
+ \begin{aligned}
+ \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}^{pre}) + A_{pre}\delta(t-t_{sp}^{post}), \\
+ \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}^{pre}), \\
+ \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}^{post}), \\
+ \tag{1}\end{aligned}
+
+ where :math:`t_{sp}^{pre}` and :math:`t_{sp}^{post}` denote the pre- and post-synaptic spike times, :math:`A_1` is the increment
+ of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
+
+ Example::
+ import brainpy as bp
+ import brainpy.math as bm
+
+ class STDPNet(bp.DynamicalSystem):
+ def __init__(self, num_pre, num_post):
+ super().__init__()
+ self.pre = bp.dyn.LifRef(num_pre, name='neu1')
+ self.post = bp.dyn.LifRef(num_post, name='neu2')
+ self.syn = bp.dyn.STDP_Song2000(
+ pre=self.pre,
+ delay=1.,
+ comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
+ weight=bp.init.Uniform(max_val=0.1)),
+ syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.post,
+ tau_s=16.8,
+ tau_t=33.7,
+ A1=0.96,
+ A2=0.53,
+ )
+
+ def update(self, I_pre, I_post):
+ self.syn()
+ self.pre(I_pre)
+ self.post(I_post)
+ conductance = self.syn.refs['syn'].g
+ Apre = self.syn.refs['pre_trace'].g
+ Apost = self.syn.refs['post_trace'].g
+ current = self.post.sum_inputs(self.post.V)
+ return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
+
+ duration = 300.
+ I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+ [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
+ I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+ [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
+
+ net = STDPNet(1, 1)
+ def run(i, I_pre, I_post):
+ pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
+ return pre_spike, post_spike, g, Apre, Apost, current, W
+
+ indices = bm.arange(0, duration, bm.dt)
+ pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post], jit=True)
+
+ Args:
+ tau_s: float, ArrayType, Callable. The time constant of :math:`A_{pre}`.
+ tau_t: float, ArrayType, Callable. The time constant of :math:`A_{post}`.
+ A1: float, ArrayType, Callable. The increment of :math:`A_{pre}` produced by a spike.
+ A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike.
+ """
+
+ def __init__(
+ self,
+ pre: JointType[DynamicalSystem, SupportAutoDelay],
+ delay: Union[None, int, float],
+ syn: ParamDescInit[DynamicalSystem],
+ comm: DynamicalSystem,
+ out: ParamDescInit[JointType[DynamicalSystem, BindCondData]],
+ post: DynamicalSystem,
+ # synapse parameters
+ tau_s: Union[float, ArrayType, Callable] = 16.8,
+ tau_t: Union[float, ArrayType, Callable] = 33.7,
+ A1: Union[float, ArrayType, Callable] = 0.96,
+ A2: Union[float, ArrayType, Callable] = 0.53,
+ # others
+ out_label: Optional[str] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None,
+ ):
+ super().__init__(name=name, mode=mode)
+
+ # synaptic models
+ check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
+ check.is_instance(syn, ParamDescInit[DynamicalSystem])
+ check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
+ check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
+ check.is_instance(post, DynamicalSystem)
+ self.pre_num = pre.num
+ self.post_num = post.num
+ self.comm = comm
+ self.syn = syn
+
+ # delay initialization
+ if not pre.has_aft_update(delay_identifier):
+ delay_ins = init_delay_by_return(pre.return_info())
+ pre.add_aft_update(delay_identifier, delay_ins)
+ delay_cls = pre.get_aft_update(delay_identifier)
+ delay_cls.register_entry(self.name, delay)
+
+ if issubclass(syn.cls, AlignPost):
+ # synapse and output initialization
+ self._post_repr = f'{out_label} // {syn.identifier} // {out.identifier}'
+ if not post.has_bef_update(self._post_repr):
+ syn_cls = syn()
+ out_cls = out()
+ if out_label is None:
+ out_name = self.name
+ else:
+ out_name = f'{out_label} // {self.name}'
+ post.add_inp_fun(out_name, out_cls)
+ post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls))
+ # references
+ self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
+ self.refs['delay'] = pre.get_aft_update(delay_identifier)
+ self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()``
+ self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()``
+
+ else:
+ # synapse initialization
+ self._syn_id = f'Delay({str(delay)}) // {syn.identifier}'
+ if not delay_cls.has_bef_update(self._syn_id):
+ # delay
+ delay_access = DelayAccess(delay_cls, delay)
+ # synapse
+ syn_cls = syn()
+ # add to "after_updates"
+ delay_cls.add_bef_update(self._syn_id, _AlignPreMg(delay_access, syn_cls))
+
+ # output initialization
+ if out_label is None:
+ out_name = self.name
+ else:
+ out_name = f'{out_label} // {self.name}'
+ post.add_inp_fun(out_name, out)
+
+ # references
+ self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()`
+ self.refs['delay'] = delay_cls.get_bef_update(self._syn_id)
+ self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn
+ self.refs['out'] = out
+
+ self.refs['pre_trace'] = self.calculate_trace(pre, delay, Expon.desc(pre.num, tau=tau_s))
+ self.refs['post_trace'] = self.calculate_trace(post, None, Expon.desc(post.num, tau=tau_t))
+ # parameters
+ self.tau_s = parameter(tau_s, sizes=self.pre_num)
+ self.tau_t = parameter(tau_t, sizes=self.post_num)
+ self.A1 = parameter(A1, sizes=self.pre_num)
+ self.A2 = parameter(A2, sizes=self.post_num)
+
+ def calculate_trace(
+ self,
+ target: DynamicalSystem,
+ delay: Union[None, int, float],
+ syn: ParamDescInit[DynamicalSystem],
+ ):
+ """Calculate the trace of the target."""
+ check.is_instance(target, DynamicalSystem)
+ check.is_instance(syn, ParamDescInit[DynamicalSystem])
+
+ # delay initialization
+ if not target.has_aft_update(delay_identifier):
+ delay_ins = init_delay_by_return(target.return_info())
+ target.add_aft_update(delay_identifier, delay_ins)
+ delay_cls = target.get_aft_update(delay_identifier)
+ delay_cls.register_entry(target.name, delay)
+
+ # synapse initialization
+ _syn_id = f'Delay({str(delay)}) // {syn.identifier}'
+ if not delay_cls.has_bef_update(_syn_id):
+ # delay
+ delay_access = DelayAccess(delay_cls, delay)
+ # synapse
+ syn_cls = syn()
+ # add to "after_updates"
+ delay_cls.add_bef_update(_syn_id, _AlignPreMg(delay_access, syn_cls))
+
+ return delay_cls.get_bef_update(_syn_id).syn
+
+ def update(self):
+ # pre spikes, and pre-synaptic variables
+ if issubclass(self.syn.cls, AlignPost):
+ pre_spike = self.refs['delay'].at(self.name)
+ x = pre_spike
+ else:
+ pre_spike = self.refs['delay'].access()
+ x = _get_return(self.refs['syn'].return_info())
+
+ # post spikes
+ post_spike = self.refs['post'].spike
+
+ # weight updates
+ Apre = self.refs['pre_trace'].g
+ Apost = self.refs['post_trace'].g
+ delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike)
+ self.comm.update_STDP(delta_w)
+
+ # currents
+ current = self.comm(x)
+ if issubclass(self.syn.cls, AlignPost):
+ self.refs['syn'].add_current(current) # synapse post current
+ else:
+ self.refs['out'].bind_cond(current)
+ return current
diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py
new file mode 100644
index 000000000..b74aec5f9
--- /dev/null
+++ b/brainpy/_src/dyn/projections/tests/test_STDP.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+
+
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+
+class Test_STDP(parameterized.TestCase):
+ def test_STDP(self):
+ bm.random.seed()
+ class STDPNet(bp.DynamicalSystem):
+ def __init__(self, num_pre, num_post):
+ super().__init__()
+ self.pre = bp.dyn.LifRef(num_pre, name='neu1')
+ self.post = bp.dyn.LifRef(num_post, name='neu2')
+ self.syn = bp.dyn.STDP_Song2000(
+ pre=self.pre,
+ delay=1.,
+ comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
+ weight=lambda s: bm.Variable(bm.random.rand(*s) * 0.1)),
+ syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
+ out=bp.dyn.COBA.desc(E=0.),
+ post=self.post,
+ tau_s=16.8,
+ tau_t=33.7,
+ A1=0.96,
+ A2=0.53,
+ )
+
+ def update(self, I_pre, I_post):
+ self.syn()
+ self.pre(I_pre)
+ self.post(I_post)
+ conductance = self.syn.refs['syn'].g
+ Apre = self.syn.refs['pre_trace'].g
+ Apost = self.syn.refs['post_trace'].g
+ current = self.post.sum_inputs(self.post.V)
+ return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
+
+ duration = 300.
+ I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+ [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
+ I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
+ [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
+
+ net = STDPNet(1, 1)
+ def run(i, I_pre, I_post):
+ pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
+ return pre_spike, post_spike, g, Apre, Apost, current, W
+
+ indices = bm.arange(0, duration, bm.dt)
+ bm.for_loop(run, [indices, I_pre, I_post], jit=True)
diff --git a/brainpy/_src/dyn/synapses/abstract_models.py b/brainpy/_src/dyn/synapses/abstract_models.py
index f6efd89fe..3cbfac08e 100644
--- a/brainpy/_src/dyn/synapses/abstract_models.py
+++ b/brainpy/_src/dyn/synapses/abstract_models.py
@@ -1030,4 +1030,4 @@ def return_info(self):
lambda shape: self.u * self.x)
-STP.__doc__ = STP.__doc__ % (pneu_doc,)
+STP.__doc__ = STP.__doc__ % (pneu_doc,)
\ No newline at end of file
diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py
index 145eec585..c212884b7 100644
--- a/brainpy/_src/dynold/synapses/base.py
+++ b/brainpy/_src/dynold/synapses/base.py
@@ -11,7 +11,7 @@
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, JointType,
- AutoDelaySupp, BindCondData, ReturnInfo)
+ SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy.errors import UnsupportedError
from brainpy.types import ArrayType
@@ -109,7 +109,7 @@ def update(self):
pass
-class _SynSTP(_SynapseComponent, ParamDesc, AutoDelaySupp):
+class _SynSTP(_SynapseComponent, ParamDesc, SupportAutoDelay):
"""Base class for synaptic short-term plasticity."""
def update(self, pre_spike):
diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py
index 78ea721c7..770d4bf30 100644
--- a/brainpy/_src/dynsys.py
+++ b/brainpy/_src/dynsys.py
@@ -10,7 +10,7 @@
from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
-from brainpy._src.mixin import AutoDelaySupp, Container, ReceiveInputProj, DelayRegister, global_delay_data
+from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
from brainpy._src.deprecations import _update_deprecate_msg
@@ -70,7 +70,7 @@ def update(self, x):
return func
-class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj):
+class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
"""Base Dynamical System class.
.. note::
@@ -487,7 +487,7 @@ class Network(DynSysGroup):
pass
-class Sequential(DynamicalSystem, AutoDelaySupp, Container):
+class Sequential(DynamicalSystem, SupportAutoDelay, Container):
"""A sequential `input-output` module.
Modules will be added to it in the order they are passed in the
@@ -557,9 +557,9 @@ def update(self, x):
def return_info(self):
last = self[-1]
- if not isinstance(last, AutoDelaySupp):
+ if not isinstance(last, SupportAutoDelay):
raise UnsupportedError(f'Does not support "return_info()" because the last node is '
- f'not instance of {AutoDelaySupp.__name__}')
+ f'not instance of {SupportAutoDelay.__name__}')
return last.return_info()
def __getitem__(self, key: Union[int, slice, str]):
diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py
index daa8a55bb..061bfe472 100644
--- a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -141,6 +141,8 @@ def fun(self):
# that has been created.
a = self.tracing_variable('a', bm.zeros, (10,))
+ .. versionadded:: 2.4.5
+
Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py
index 4a9165420..61c7b7f0d 100644
--- a/brainpy/_src/math/object_transform/controls.py
+++ b/brainpy/_src/math/object_transform/controls.py
@@ -526,25 +526,17 @@ def cond(
node_deprecation(child_objs)
dyn_vars = get_stack_cache((true_fun, false_fun))
- _transform = _get_cond_transform(VariableStack() if dyn_vars is None else dyn_vars,
- pred,
- true_fun,
- false_fun)
- if jax.config.jax_disable_jit:
- dyn_values, res = _transform(operands)
-
- else:
+ if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('cond'):
- dyn_vars, rets = evaluate_dyn_vars(
- _transform,
- operands,
- use_eval_shape=current_transform_number() <= 1
- )
+ dyn_vars1, rets = evaluate_dyn_vars(true_fun, *operands, use_eval_shape=current_transform_number() <= 1)
+ dyn_vars2, rets = evaluate_dyn_vars(false_fun, *operands, use_eval_shape=current_transform_number() <= 1)
+ dyn_vars = dyn_vars1 + dyn_vars2
cache_stack((true_fun, false_fun), dyn_vars)
if current_transform_number() > 0:
- return rets[1]
- dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
+ return rets
+ dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
+ dyn_values, res = _get_cond_transform(dyn_vars, pred, true_fun, false_fun)(operands)
for k in dyn_values.keys():
dyn_vars[k]._value = dyn_values[k]
return res
@@ -1009,22 +1001,17 @@ def while_loop(
if not isinstance(operands, (list, tuple)):
operands = (operands,)
- if jax.config.jax_disable_jit:
- dyn_vars = VariableStack()
-
- else:
- dyn_vars = get_stack_cache(body_fun)
-
+ dyn_vars = get_stack_cache((body_fun, cond_fun))
+ if not jax.config.jax_disable_jit:
if dyn_vars is None:
with new_transform('while_loop'):
- dyn_vars, rets = evaluate_dyn_vars(
- _get_while_transform(cond_fun, body_fun, VariableStack()),
- operands
- )
- cache_stack(body_fun, dyn_vars)
+ dyn_vars1, _ = evaluate_dyn_vars(cond_fun, *operands, use_eval_shape=current_transform_number() <= 1)
+ dyn_vars2, rets = evaluate_dyn_vars(body_fun, *operands, use_eval_shape=current_transform_number() <= 1)
+ dyn_vars = dyn_vars1 + dyn_vars2
+ cache_stack((body_fun, cond_fun), dyn_vars)
if current_transform_number():
- return rets[1]
-
+ return rets
+ dyn_vars = VariableStack() if dyn_vars is None else dyn_vars
dyn_values, out = _get_while_transform(cond_fun, body_fun, dyn_vars)(operands)
for k, v in dyn_vars.items():
v._value = dyn_values[k]
diff --git a/brainpy/_src/math/object_transform/tests/test_controls.py b/brainpy/_src/math/object_transform/tests/test_controls.py
index 7203adb6f..5295d80db 100644
--- a/brainpy/_src/math/object_transform/tests/test_controls.py
+++ b/brainpy/_src/math/object_transform/tests/test_controls.py
@@ -132,6 +132,13 @@ def update(self):
self.assertTrue(bm.allclose(cls.a, 10.))
+class TestCond(unittest.TestCase):
+ def test1(self):
+ bm.random.seed(1)
+ bm.cond(True, lambda: bm.random.random(10), lambda: bm.random.random(10), ())
+ bm.cond(False, lambda: bm.random.random(10), lambda: bm.random.random(10), ())
+
+
class TestIfElse(unittest.TestCase):
def test1(self):
def f(a):
@@ -221,6 +228,33 @@ def body(x, y):
print()
print(res)
+ def test2(self):
+ bm.random.seed()
+
+ a = bm.Variable(bm.zeros(1))
+ b = bm.Variable(bm.ones(1))
+
+ def cond(x, y):
+ return x < 6.
+
+ def body(x, y):
+ a.value += x
+ b.value *= y
+ return x + b[0], y + 1.
+
+ res = bm.while_loop(body, cond, operands=(1., 1.))
+ print()
+ print(res)
+
+ with jax.disable_jit():
+ a = bm.Variable(bm.zeros(1))
+ b = bm.Variable(bm.ones(1))
+
+ res2 = bm.while_loop(body, cond, operands=(1., 1.))
+ print(res2)
+ self.assertTrue(bm.array_equal(res2[0], res[0]))
+ self.assertTrue(bm.array_equal(res2[1], res[1]))
+
def test3(self):
bm.random.seed()
@@ -242,32 +276,27 @@ def body(x, y):
print(a)
print(b)
- def test2(self):
+ def test4(self):
bm.random.seed()
a = bm.Variable(bm.zeros(1))
b = bm.Variable(bm.ones(1))
def cond(x, y):
- return x < 6.
+ a.value += 1
+ return bm.all(a.value < 6.)
def body(x, y):
a.value += x
b.value *= y
- return x + b[0], y + 1.
res = bm.while_loop(body, cond, operands=(1., 1.))
- print()
+ self.assertTrue(bm.allclose(a, 5.))
+ self.assertTrue(bm.allclose(b, 1.))
print(res)
-
- with jax.disable_jit():
- a = bm.Variable(bm.zeros(1))
- b = bm.Variable(bm.ones(1))
-
- res2 = bm.while_loop(body, cond, operands=(1., 1.))
- print(res2)
- self.assertTrue(bm.array_equal(res2[0], res[0]))
- self.assertTrue(bm.array_equal(res2[1], res[1]))
+ print(a)
+ print(b)
+ print()
class TestDebugAndCompile(parameterized.TestCase):
diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py
index eb04c5d2e..e989908a0 100644
--- a/brainpy/_src/math/random.py
+++ b/brainpy/_src/math/random.py
@@ -22,7 +22,7 @@
__all__ = [
'RandomState', 'Generator', 'DEFAULT',
- 'seed', 'default_rng', 'split_key',
+ 'seed', 'default_rng', 'split_key', 'split_keys',
# numpy compatibility
'rand', 'randint', 'random_integers', 'randn', 'random',
@@ -1258,6 +1258,8 @@ def split_keys(n):
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.
+ .. versionadded:: 2.4.5
+
Parameters
----------
n : int
@@ -1267,6 +1269,15 @@ def split_keys(n):
def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
+ """Clone the random state according to the given setting.
+
+ Args:
+ seed_or_key: The seed (an integer) or the random key.
+ clone: Bool. Whether clone the default random state.
+
+ Returns:
+ The random state.
+ """
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
diff --git a/brainpy/_src/math/surrogate/_one_input.py b/brainpy/_src/math/surrogate/_one_input.py
index 5ddb94254..23f151ee0 100644
--- a/brainpy/_src/math/surrogate/_one_input.py
+++ b/brainpy/_src/math/surrogate/_one_input.py
@@ -36,6 +36,13 @@
class Sigmoid(Surrogate):
+ """Spike function with the sigmoid-shaped surrogate gradient.
+
+ See Also
+ --------
+ sigmoid
+
+ """
def __init__(self, alpha=4., origin=False):
self.alpha = alpha
self.origin = origin
@@ -118,6 +125,13 @@ def grad(dz):
class PiecewiseQuadratic(Surrogate):
+ """Judge spiking state with a piecewise quadratic function.
+
+ See Also
+ --------
+ piecewise_quadratic
+
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -220,6 +234,12 @@ def grad(dz):
class PiecewiseExp(Surrogate):
+ """Judge spiking state with a piecewise exponential function.
+
+ See Also
+ --------
+ piecewise_exp
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -308,6 +328,12 @@ def grad(dz):
class SoftSign(Surrogate):
+ """Judge spiking state with a soft sign function.
+
+ See Also
+ --------
+ soft_sign
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -391,6 +417,12 @@ def grad(dz):
class Arctan(Surrogate):
+ """Judge spiking state with an arctan function.
+
+ See Also
+ --------
+ arctan
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -473,6 +505,12 @@ def grad(dz):
class NonzeroSignLog(Surrogate):
+ """Judge spiking state with a nonzero sign log function.
+
+ See Also
+ --------
+ nonzero_sign_log
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -568,6 +606,12 @@ def grad(dz):
class ERF(Surrogate):
+ """Judge spiking state with an erf function.
+
+ See Also
+ --------
+ erf
+ """
def __init__(self, alpha=1., origin=False):
self.alpha = alpha
self.origin = origin
@@ -660,6 +704,12 @@ def grad(dz):
class PiecewiseLeakyRelu(Surrogate):
+ """Judge spiking state with a piecewise leaky relu function.
+
+ See Also
+ --------
+ piecewise_leaky_relu
+ """
def __init__(self, c=0.01, w=1., origin=False):
self.c = c
self.w = w
@@ -771,6 +821,12 @@ def grad(dz):
class SquarewaveFourierSeries(Surrogate):
+ """Judge spiking state with a squarewave fourier series.
+
+ See Also
+ --------
+ squarewave_fourier_series
+ """
def __init__(self, n=2, t_period=8., origin=False):
self.n = n
self.t_period = t_period
@@ -863,6 +919,12 @@ def grad(dz):
class S2NN(Surrogate):
+ """Judge spiking state with the S2NN surrogate spiking function.
+
+ See Also
+ --------
+ s2nn
+ """
def __init__(self, alpha=4., beta=1., epsilon=1e-8, origin=False):
self.alpha = alpha
self.beta = beta
@@ -969,6 +1031,12 @@ def grad(dz):
class QPseudoSpike(Surrogate):
+ """Judge spiking state with the q-PseudoSpike surrogate function.
+
+ See Also
+ --------
+ q_pseudo_spike
+ """
def __init__(self, alpha=2., origin=False):
self.alpha = alpha
self.origin = origin
@@ -1062,6 +1130,12 @@ def grad(dz):
class LeakyRelu(Surrogate):
+ """Judge spiking state with the Leaky ReLU function.
+
+ See Also
+ --------
+ leaky_relu
+ """
def __init__(self, alpha=0.1, beta=1., origin=False):
self.alpha = alpha
self.beta = beta
@@ -1156,6 +1230,12 @@ def grad(dz):
class LogTailedRelu(Surrogate):
+ """Judge spiking state with the Log-tailed ReLU function.
+
+ See Also
+ --------
+ log_tailed_relu
+ """
def __init__(self, alpha=0., origin=False):
self.alpha = alpha
self.origin = origin
@@ -1260,6 +1340,12 @@ def grad(dz):
class ReluGrad(Surrogate):
+ """Judge spiking state with the ReLU gradient function.
+
+ See Also
+ --------
+ relu_grad
+ """
def __init__(self, alpha=0.3, width=1.):
self.alpha = alpha
self.width = width
@@ -1337,6 +1423,12 @@ def grad(dz):
class GaussianGrad(Surrogate):
+ """Judge spiking state with the Gaussian gradient function.
+
+ See Also
+ --------
+ gaussian_grad
+ """
def __init__(self, sigma=0.5, alpha=0.5):
self.sigma = sigma
self.alpha = alpha
@@ -1413,6 +1505,12 @@ def grad(dz):
class MultiGaussianGrad(Surrogate):
+ """Judge spiking state with the multi-Gaussian gradient function.
+
+ See Also
+ --------
+ multi_gaussian_grad
+ """
def __init__(self, h=0.15, s=6.0, sigma=0.5, scale=0.5):
self.h = h
self.s = s
@@ -1503,6 +1601,12 @@ def grad(dz):
class InvSquareGrad(Surrogate):
+ """Judge spiking state with the inverse-square surrogate gradient function.
+
+ See Also
+ --------
+ inv_square_grad
+ """
def __init__(self, alpha=100.):
self.alpha = alpha
@@ -1571,6 +1675,12 @@ def grad(dz):
class SlayerGrad(Surrogate):
+ """Judge spiking state with the slayer surrogate gradient function.
+
+ See Also
+ --------
+ slayer_grad
+ """
def __init__(self, alpha=1.):
self.alpha = alpha
diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py
index fce2aca18..23cd703bf 100644
--- a/brainpy/_src/mixin.py
+++ b/brainpy/_src/mixin.py
@@ -1,6 +1,5 @@
import numbers
import sys
-import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
@@ -28,11 +27,15 @@
'ParamDesc',
'ParamDescInit',
'AlignPost',
- 'AutoDelaySupp',
'Container',
'TreeNode',
'BindCondData',
'JointType',
+ 'SupportSTDP',
+ 'SupportAutoDelay',
+ 'SupportInputProj',
+ 'SupportOnline',
+ 'SupportOffline',
]
global_delay_data = dict()
@@ -46,59 +49,6 @@ class MixIn(object):
pass
-class ReceiveInputProj(MixIn):
- """The :py:class:`~.MixIn` that receives the input projections.
-
- Note that the subclass should define a ``cur_inputs`` attribute.
-
- """
- cur_inputs: bm.node_dict
-
- def add_inp_fun(self, key: Any, fun: Callable):
- """Add an input function.
-
- Args:
- key: The dict key.
- fun: The function to generate inputs.
- """
- if not callable(fun):
- raise TypeError('Must be a function.')
- if key in self.cur_inputs:
- raise ValueError(f'Key "{key}" has been defined and used.')
- self.cur_inputs[key] = fun
-
- def get_inp_fun(self, key):
- """Get the input function.
-
- Args:
- key: The key.
-
- Returns:
- The input function which generates currents.
- """
- return self.cur_inputs.get(key)
-
- def sum_inputs(self, *args, init=0., label=None, **kwargs):
- """Summarize all inputs by the defined input functions ``.cur_inputs``.
-
- Args:
- *args: The arguments for input functions.
- init: The initial input data.
- **kwargs: The arguments for input functions.
-
- Returns:
- The total currents.
- """
- if label is None:
- for key, out in self.cur_inputs.items():
- init = init + out(*args, **kwargs)
- else:
- for key, out in self.cur_inputs.items():
- if key.startswith(label + ' // '):
- init = init + out(*args, **kwargs)
- return init
-
-
class ParamDesc(MixIn):
""":py:class:`~.MixIn` indicates the function for describing initialization parameters.
@@ -207,13 +157,6 @@ def get_data(self):
return init
-class AutoDelaySupp(MixIn):
- """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
-
- def return_info(self) -> Union[bm.Variable, ReturnInfo]:
- raise NotImplementedError('Must implement the "return_info()" function.')
-
-
class Container(MixIn):
"""Container :py:class:`~.MixIn` which wrap a group of objects.
"""
@@ -347,7 +290,7 @@ def register_delay_at(
if delay_identifier is None: from brainpy._src.delay import delay_identifier
if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem
- assert isinstance(self, AutoDelaySupp), f'self must be an instance of {AutoDelaySupp.__name__}'
+ assert isinstance(self, SupportAutoDelay), f'self must be an instance of {SupportAutoDelay.__name__}'
assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}'
if not self.has_aft_update(delay_identifier):
self.add_aft_update(delay_identifier, init_delay_by_return(self.return_info()))
@@ -549,8 +492,97 @@ def get_delay_var(self, name):
return global_delay_data[name]
+class SupportInputProj(MixIn):
+ """The :py:class:`~.MixIn` that receives the input projections.
+
+ Note that the subclass should define a ``cur_inputs`` attribute.
+
+ """
+ cur_inputs: bm.node_dict
+
+ def add_inp_fun(self, key: Any, fun: Callable):
+ """Add an input function.
+
+ Args:
+ key: The dict key.
+ fun: The function to generate inputs.
+ """
+ if not callable(fun):
+ raise TypeError('Must be a function.')
+ if key in self.cur_inputs:
+ raise ValueError(f'Key "{key}" has been defined and used.')
+ self.cur_inputs[key] = fun
+
+ def get_inp_fun(self, key):
+ """Get the input function.
+
+ Args:
+ key: The key.
+
+ Returns:
+ The input function which generates currents.
+ """
+ return self.cur_inputs.get(key)
+
+ def sum_inputs(self, *args, init=0., label=None, **kwargs):
+ """Summarize all inputs by the defined input functions ``.cur_inputs``.
+
+ Args:
+ *args: The arguments for input functions.
+ init: The initial input data.
+ **kwargs: The arguments for input functions.
+
+ Returns:
+ The total currents.
+ """
+ if label is None:
+ for key, out in self.cur_inputs.items():
+ init = init + out(*args, **kwargs)
+ else:
+ for key, out in self.cur_inputs.items():
+ if key.startswith(label + ' // '):
+ init = init + out(*args, **kwargs)
+ return init
+
+
+class SupportAutoDelay(MixIn):
+ """``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""
+
+ def return_info(self) -> Union[bm.Variable, ReturnInfo]:
+ raise NotImplementedError('Must implement the "return_info()" function.')
+
+
+class SupportOnline(MixIn):
+ """:py:class:`~.MixIn` to support the online training methods.
+
+ .. versionadded:: 2.4.5
+ """
+
+ online_fit_by: Optional # methods for online fitting
+
+ def online_init(self):
+ raise NotImplementedError
+
+ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
+ raise NotImplementedError
+
+
+class SupportOffline(MixIn):
+ """:py:class:`~.MixIn` to support the offline training methods.
+
+ .. versionadded:: 2.4.5
+ """
+
+ offline_fit_by: Optional # methods for offline fitting
+
+ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):
+ raise NotImplementedError
+
+
class BindCondData(MixIn):
"""Bind temporary conductance data.
+
+
"""
_conductance: Optional
@@ -561,6 +593,16 @@ def unbind_cond(self):
self._conductance = None
+class SupportSTDP(MixIn):
+ """Support synaptic plasticity by modifying the weights.
+ """
+ def update_STDP(self,
+ dW: Union[bm.Array, jax.Array],
+ constraints: Optional[Callable] = None,
+ ):
+ raise NotImplementedError
+
+
T = TypeVar('T')
@@ -598,7 +640,7 @@ class UnionType2(MixIn):
>>> import brainpy as bp
>>>
- >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.AutoDelaySupp])
+ >>> isinstance(bp.dyn.Expon(1), JointType[bp.DynamicalSystem, bp.mixin.ParamDesc, bp.mixin.SupportAutoDelay])
"""
@classmethod
diff --git a/brainpy/dyn/__init__.py b/brainpy/dyn/__init__.py
index 297c0c50b..00587fb06 100644
--- a/brainpy/dyn/__init__.py
+++ b/brainpy/dyn/__init__.py
@@ -5,6 +5,7 @@
from .neurons import *
from .synapses import *
from .projections import *
+from .plasticity import *
from .others import *
from .outs import *
from .rates import *
diff --git a/brainpy/dyn/plasticity.py b/brainpy/dyn/plasticity.py
new file mode 100644
index 000000000..db978b390
--- /dev/null
+++ b/brainpy/dyn/plasticity.py
@@ -0,0 +1,3 @@
+from brainpy._src.dyn.projections.plasticity import (
+ STDP_Song2000 as STDP_Song2000,
+)
diff --git a/brainpy/dyn/projections.py b/brainpy/dyn/projections.py
index 6ee6f300a..2954b7871 100644
--- a/brainpy/dyn/projections.py
+++ b/brainpy/dyn/projections.py
@@ -1,5 +1,4 @@
-
from brainpy._src.dyn.projections.aligns import (
VanillaProj,
ProjAlignPostMg1,
@@ -20,3 +19,4 @@
PoissonInput as PoissonInput,
)
+
diff --git a/brainpy/mixin.py b/brainpy/mixin.py
index a3f17c7aa..ab3c3cd37 100644
--- a/brainpy/mixin.py
+++ b/brainpy/mixin.py
@@ -1,13 +1,14 @@
from brainpy._src.mixin import (
MixIn as MixIn,
- ReceiveInputProj as ReceiveInputProj,
+ SupportInputProj as SupportInputProj,
AlignPost as AlignPost,
- AutoDelaySupp as AutoDelaySupp,
+ SupportAutoDelay as SupportAutoDelay,
ParamDesc as ParamDesc,
ParamDescInit as ParamDescInit,
BindCondData as BindCondData,
Container as Container,
TreeNode as TreeNode,
JointType as JointType,
+ SupportSTDP as SupportPlasticity,
)
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 000000000..aa728cada
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,23 @@
+FROM ubuntu:22.04
+
+ENV TZ=Asia/Dubai
+RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
+
+RUN apt update
+RUN apt install -y --no-install-recommends software-properties-common
+
+RUN apt update && apt install -y python3-pip
+
+RUN ln -sf /usr/bin/python3.10 /usr/bin/python && \
+ ln -sf /usr/bin/pip3 /usr/bin/pip
+
+
+RUN pip --no-cache-dir install --upgrade pip && \
+ pip --no-cache-dir install --upgrade setuptools && \
+ pip --no-cache-dir install --upgrade wheel
+
+ADD . /usr/src/app
+WORKDIR /usr/src/app
+
+RUN pip --no-cache-dir install --upgrade "jax[cpu]"
+RUN pip --no-cache-dir install --upgrade -r requirements.txt
diff --git a/docker/requirements.txt b/docker/requirements.txt
new file mode 100644
index 000000000..460371906
--- /dev/null
+++ b/docker/requirements.txt
@@ -0,0 +1,16 @@
+numpy
+tqdm
+msgpack
+matplotlib>=3.4
+jax
+jaxlib
+scipy>=1.1.0
+brainpy
+brainpylib
+brainpy_datasets
+h5py
+pathos
+
+# test requirements
+pytest
+absl-py
diff --git a/docs/_templates/class_template.rst b/docs/_templates/class_template.rst
index d9135b2c1..a902dc6d9 100644
--- a/docs/_templates/class_template.rst
+++ b/docs/_templates/class_template.rst
@@ -5,7 +5,9 @@
.. autoclass:: {{ objname }}
- .. automethod:: __init__
+ {% for item in methods %}
+ .. automethod:: {{ item }}
+ {%- endfor %}
{% block methods %}
diff --git a/docs/_templates/classtemplate.rst b/docs/_templates/classtemplate.rst
new file mode 100644
index 000000000..57b89b777
--- /dev/null
+++ b/docs/_templates/classtemplate.rst
@@ -0,0 +1,10 @@
+.. role:: hidden
+ :class: hidden-section
+.. currentmodule:: {{ module }}
+
+
+{{ name | underline}}
+
+.. autoclass:: {{ name }}
+ :members:
+
diff --git a/docs/advanced_tutorials.rst b/docs/advanced_tutorials.rst
index 1cb343846..5c8cba0fd 100644
--- a/docs/advanced_tutorials.rst
+++ b/docs/advanced_tutorials.rst
@@ -4,44 +4,12 @@ Advanced Tutorials
This section contains tutorials that illustrate more advanced features of BrainPy.
-
-Advanced math
--------------
-
-
-.. toctree::
- :maxdepth: 1
-
- tutorial_advanced/differentiation.ipynb
-
-
-
-Interoperation
---------------
-
-
-.. toctree::
- :maxdepth: 1
-
- tutorial_advanced/integrate_flax_into_brainpy.ipynb
- tutorial_advanced/integrate_bp_lif_into_flax.ipynb
- tutorial_advanced/integrate_bp_convlstm_into_flax.ipynb
-
-
-Advanced dynamics analysis
---------------------------
-
-.. toctree::
- :maxdepth: 1
-
- tutorial_advanced/advanced_lowdim_analysis.ipynb
-
-
-Developer guides
----------------
-
.. toctree::
- :maxdepth: 1
+ :maxdepth: 2
- tutorial_advanced/contributing.md
+ tutorial_advanced/1_advanced_math.rst
+ tutorial_advanced/2_interoperation.rst
+ tutorial_advanced/3_dedicated_operators.rst
+ tutorial_advanced/4_developer_guides.rst
+ tutorial_advanced/5_others.rst
diff --git a/docs/api.rst b/docs/api.rst
index 65bc5b088..076ce48c9 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -5,31 +5,31 @@ API Documentation
:maxdepth: 1
apis/auto/changelog.rst
- apis/auto/brainpy.rst
- apis/auto/math.rst
- apis/auto/dnn.rst
- apis/auto/dyn.rst
- apis/auto/integrators.rst
- apis/auto/analysis.rst
- apis/auto/connect.rst
- apis/auto/encoding.rst
- apis/auto/initialize.rst
- apis/auto/inputs.rst
- apis/auto/losses.rst
- apis/auto/measure.rst
- apis/auto/optim.rst
- apis/auto/running.rst
- apis/auto/mixin.rst
+ apis/brainpy.rst
+ apis/math.rst
+ apis/dnn.rst
+ apis/dyn.rst
+ apis/integrators.rst
+ apis/analysis.rst
+ apis/connect.rst
+ apis/encoding.rst
+ apis/initialize.rst
+ apis/inputs.rst
+ apis/losses.rst
+ apis/measure.rst
+ apis/optim.rst
+ apis/running.rst
+ apis/mixin.rst
The following APIs will no longer be maintained in the future, but you can still use them normally.
.. toctree::
:maxdepth: 1
- apis/channels.rst
- apis/neurons.rst
- apis/rates.rst
- apis/synapses.rst
- apis/synouts.rst
- apis/synplast.rst
- apis/layers.rst
+ apis/deprecated/channels.rst
+ apis/deprecated/neurons.rst
+ apis/deprecated/rates.rst
+ apis/deprecated/synapses.rst
+ apis/deprecated/synouts.rst
+ apis/deprecated/synplast.rst
+ apis/deprecated/layers.rst
diff --git a/docs/apis/analysis.rst b/docs/apis/analysis.rst
new file mode 100644
index 000000000..897fa46c1
--- /dev/null
+++ b/docs/apis/analysis.rst
@@ -0,0 +1,37 @@
+``brainpy.analysis`` module
+===========================
+
+.. currentmodule:: brainpy.analysis
+.. automodule:: brainpy.analysis
+
+.. contents::
+ :local:
+ :depth: 1
+
+Low-dimensional Analyzers
+-------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ PhasePlane1D
+ PhasePlane2D
+ Bifurcation1D
+ Bifurcation2D
+ FastSlow1D
+ FastSlow2D
+
+
+High-dimensional Analyzers
+--------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SlowPointFinder
+
+
diff --git a/docs/apis/brainpy.dyn.base.rst b/docs/apis/brainpy.dyn.base.rst
new file mode 100644
index 000000000..25d794f7e
--- /dev/null
+++ b/docs/apis/brainpy.dyn.base.rst
@@ -0,0 +1,14 @@
+Base Classes
+============
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ NeuDyn
+ SynDyn
+ IonChaDyn
diff --git a/docs/apis/brainpy.dyn.channels.rst b/docs/apis/brainpy.dyn.channels.rst
new file mode 100644
index 000000000..80a1af30d
--- /dev/null
+++ b/docs/apis/brainpy.dyn.channels.rst
@@ -0,0 +1,101 @@
+Ion Channel Dynamics
+====================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+.. contents::
+ :local:
+ :depth: 1
+
+
+Base Classes
+------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ IonChannel
+
+
+
+Calcium Channels
+-----------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ CalciumChannel
+ ICaN_IS2008
+ ICaT_HM1992
+ ICaT_HP1992
+ ICaHT_HM1992
+ ICaHT_Re1993
+ ICaL_IS2008
+
+
+Potassium Channels
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ PotassiumChannel
+ IKDR_Ba2002v2
+ IK_TM1991v2
+ IK_HH1952v2
+ IKA1_HM1992v2
+ IKA2_HM1992v2
+ IKK2A_HM1992v2
+ IKK2B_HM1992v2
+ IKNI_Ya1989v2
+ IK_Leak
+ IKDR_Ba2002
+ IK_TM1991
+ IK_HH1952
+ IKA1_HM1992
+ IKA2_HM1992
+ IKK2A_HM1992
+ IKK2B_HM1992
+ IKNI_Ya1989
+ IKL
+
+
+
+Sodium Channels
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SodiumChannel
+ INa_Ba2002
+ INa_TM1991
+ INa_HH1952
+ INa_Ba2002v2
+ INa_TM1991v2
+ INa_HH1952v2
+
+
+Other Channels
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Ih_HM1992
+ Ih_De1996
+ IAHP_De1994v2
+ IAHP_De1994
+ LeakyChannel
+ IL
diff --git a/docs/apis/brainpy.dyn.ions.rst b/docs/apis/brainpy.dyn.ions.rst
new file mode 100644
index 000000000..5d18643b2
--- /dev/null
+++ b/docs/apis/brainpy.dyn.ions.rst
@@ -0,0 +1,23 @@
+Ion Dynamics
+======================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ mix_ions
+ Ion
+ MixIons
+ Calcium
+ CalciumFixed
+ CalciumDetailed
+ CalciumFirstOrder
+ Sodium
+ SodiumFixed
+ Potassium
+ PotassiumFixed
diff --git a/docs/apis/brainpy.dyn.neurons.rst b/docs/apis/brainpy.dyn.neurons.rst
new file mode 100644
index 000000000..980d18516
--- /dev/null
+++ b/docs/apis/brainpy.dyn.neurons.rst
@@ -0,0 +1,72 @@
+Neuron Dynamics
+===============
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+Reduced Neuron Models
+---------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Lif
+ LifLTC
+ LifRefLTC
+ LifRef
+ ExpIF
+ ExpIFLTC
+ ExpIFRefLTC
+ ExpIFRef
+ AdExIF
+ AdExIFLTC
+ AdExIFRefLTC
+ AdExIFRef
+ QuaIF
+ QuaIFLTC
+ QuaIFRefLTC
+ QuaIFRef
+ AdQuaIF
+ AdQuaIFLTC
+ AdQuaIFRefLTC
+ AdQuaIFRef
+ Gif
+ GifLTC
+ GifRefLTC
+ GifRef
+ Izhikevich
+ IzhikevichLTC
+ IzhikevichRefLTC
+ IzhikevichRef
+ HHTypedNeuron
+ CondNeuGroupLTC
+ CondNeuGroup
+ HH
+ HHLTC
+ MorrisLecar
+ MorrisLecarLTC
+ WangBuzsakiHH
+ WangBuzsakiHHLTC
+
+
+Hodgkin–Huxley Neuron Models
+----------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ HHTypedNeuron
+ CondNeuGroupLTC
+ CondNeuGroup
+ HH
+ HHLTC
+ MorrisLecar
+ MorrisLecarLTC
+ WangBuzsakiHH
+ WangBuzsakiHHLTC
+
diff --git a/docs/apis/brainpy.dyn.others.rst b/docs/apis/brainpy.dyn.others.rst
new file mode 100644
index 000000000..aae94ff63
--- /dev/null
+++ b/docs/apis/brainpy.dyn.others.rst
@@ -0,0 +1,21 @@
+Common Dynamical Models
+=======================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Leaky
+ Integrator
+ InputGroup
+ OutputGroup
+ SpikeTimeGroup
+ PoissonGroup
+ OUProcess
+
+
+
diff --git a/docs/apis/brainpy.dyn.outs.rst b/docs/apis/brainpy.dyn.outs.rst
new file mode 100644
index 000000000..892f700e2
--- /dev/null
+++ b/docs/apis/brainpy.dyn.outs.rst
@@ -0,0 +1,16 @@
+Synaptic Outputs
+================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SynOut
+ COBA
+ CUBA
+ MgBlock
\ No newline at end of file
diff --git a/docs/apis/brainpy.dyn.plasticity.rst b/docs/apis/brainpy.dyn.plasticity.rst
new file mode 100644
index 000000000..597c71aa5
--- /dev/null
+++ b/docs/apis/brainpy.dyn.plasticity.rst
@@ -0,0 +1,12 @@
+Synaptic Plasticity
+===================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ STDP_Song2000
diff --git a/docs/apis/brainpy.dyn.projections.rst b/docs/apis/brainpy.dyn.projections.rst
new file mode 100644
index 000000000..b1dcb1219
--- /dev/null
+++ b/docs/apis/brainpy.dyn.projections.rst
@@ -0,0 +1,24 @@
+Synaptic Projections
+======================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ VanillaProj
+ ProjAlignPostMg1
+ ProjAlignPostMg2
+ ProjAlignPost1
+ ProjAlignPost2
+ ProjAlignPreMg1
+ ProjAlignPreMg2
+ ProjAlignPre1
+ ProjAlignPre2
+ SynConn
+ PoissonInput
diff --git a/docs/apis/brainpy.dyn.rates.rst b/docs/apis/brainpy.dyn.rates.rst
new file mode 100644
index 000000000..8aa9af007
--- /dev/null
+++ b/docs/apis/brainpy.dyn.rates.rst
@@ -0,0 +1,20 @@
+Population Rate Models
+======================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ FHN
+ FeedbackFHN
+ QIF
+ StuartLandauOscillator
+ WilsonCowanModel
+ ThresholdLinearModel
+
+
diff --git a/docs/apis/brainpy.dyn.synapses.rst b/docs/apis/brainpy.dyn.synapses.rst
new file mode 100644
index 000000000..59062d180
--- /dev/null
+++ b/docs/apis/brainpy.dyn.synapses.rst
@@ -0,0 +1,25 @@
+Synaptic Dynamics
+======================
+
+.. currentmodule:: brainpy.dyn
+.. automodule:: brainpy.dyn
+
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Delta
+ Expon
+ Alpha
+ DualExpon
+ DualExponV2
+ NMDA
+ STD
+ STP
+ AMPA
+ GABAa
+ BioNMDA
+ DiffusiveCoupling
+ AdditiveCoupling
\ No newline at end of file
diff --git a/docs/apis/brainpy.rst b/docs/apis/brainpy.rst
new file mode 100644
index 000000000..bff268a11
--- /dev/null
+++ b/docs/apis/brainpy.rst
@@ -0,0 +1,81 @@
+``brainpy`` module
+==================
+
+.. currentmodule:: brainpy
+.. automodule:: brainpy
+
+.. contents::
+ :local:
+ :depth: 1
+
+Numerical Integration of Differential Equations
+-----------------------------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Integrator
+ JointEq
+ IntegratorRunner
+ odeint
+ sdeint
+ fdeint
+
+
+Building Dynamical Systems
+--------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ DynamicalSystem
+ DynSysGroup
+ Sequential
+ Network
+ Dynamic
+ Projection
+
+
+Simulating Dynamical Systems
+----------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ DSRunner
+
+
+Training Dynamical Systems
+--------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ DSTrainer
+ BPTT
+ BPFF
+ OnlineTrainer
+ ForceTrainer
+ OfflineTrainer
+ RidgeTrainer
+
+
+Dynamical System Helpers
+------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ LoopOverTime
+
+
diff --git a/docs/apis/connect.rst b/docs/apis/connect.rst
new file mode 100644
index 000000000..9c42fbabb
--- /dev/null
+++ b/docs/apis/connect.rst
@@ -0,0 +1,100 @@
+``brainpy.connect`` module
+==========================
+
+.. currentmodule:: brainpy.connect
+.. automodule:: brainpy.connect
+
+.. contents::
+ :local:
+ :depth: 1
+
+Base Connection Classes and Tools
+---------------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ set_default_dtype
+ get_idx_type
+ mat2coo
+ mat2csc
+ mat2csr
+ csr2csc
+ csr2mat
+ csr2coo
+ coo2csr
+ coo2csc
+ coo2mat
+ coo2mat_num
+ mat2mat_num
+ visualizeMat
+ MAT_DTYPE
+ IDX_DTYPE
+ Connector
+ TwoEndConnector
+ OneEndConnector
+ CONN_MAT
+ PRE_IDS
+ POST_IDS
+ PRE2POST
+ POST2PRE
+ PRE2SYN
+ POST2SYN
+ SUPPORTED_SYN_STRUCTURE
+
+
+Custom Connections
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ MatConn
+ IJConn
+ CSRConn
+ SparseMatConn
+
+
+Random Connections
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ FixedProb
+ FixedPreNum
+ FixedPostNum
+ FixedTotalNum
+ GaussianProb
+ ProbDist
+ SmallWorld
+ ScaleFreeBA
+ ScaleFreeBADual
+ PowerLaw
+
+
+Regular Connections
+-------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ One2One
+ All2All
+ GridFour
+ GridEight
+ GridN
+ one2one
+ all2all
+ grid_four
+ grid_eight
+
+
diff --git a/docs/apis/channels.rst b/docs/apis/deprecated/channels.rst
similarity index 100%
rename from docs/apis/channels.rst
rename to docs/apis/deprecated/channels.rst
diff --git a/docs/apis/layers.rst b/docs/apis/deprecated/layers.rst
similarity index 100%
rename from docs/apis/layers.rst
rename to docs/apis/deprecated/layers.rst
diff --git a/docs/apis/neurons.rst b/docs/apis/deprecated/neurons.rst
similarity index 100%
rename from docs/apis/neurons.rst
rename to docs/apis/deprecated/neurons.rst
diff --git a/docs/apis/rates.rst b/docs/apis/deprecated/rates.rst
similarity index 100%
rename from docs/apis/rates.rst
rename to docs/apis/deprecated/rates.rst
diff --git a/docs/apis/synapses.rst b/docs/apis/deprecated/synapses.rst
similarity index 100%
rename from docs/apis/synapses.rst
rename to docs/apis/deprecated/synapses.rst
diff --git a/docs/apis/synouts.rst b/docs/apis/deprecated/synouts.rst
similarity index 100%
rename from docs/apis/synouts.rst
rename to docs/apis/deprecated/synouts.rst
diff --git a/docs/apis/synplast.rst b/docs/apis/deprecated/synplast.rst
similarity index 100%
rename from docs/apis/synplast.rst
rename to docs/apis/deprecated/synplast.rst
diff --git a/docs/apis/dnn.rst b/docs/apis/dnn.rst
new file mode 100644
index 000000000..736066ce4
--- /dev/null
+++ b/docs/apis/dnn.rst
@@ -0,0 +1,184 @@
+``brainpy.dnn`` module
+======================
+
+.. currentmodule:: brainpy.dnn
+.. automodule:: brainpy.dnn
+
+.. contents::
+ :local:
+ :depth: 1
+
+Non-linear Activations
+----------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Activation
+ Flatten
+ FunAsLayer
+ Threshold
+ ReLU
+ RReLU
+ Hardtanh
+ ReLU6
+ Sigmoid
+ Hardsigmoid
+ Tanh
+ SiLU
+ Mish
+ Hardswish
+ ELU
+ CELU
+ SELU
+ GLU
+ GELU
+ Hardshrink
+ LeakyReLU
+ LogSigmoid
+ Softplus
+ Softshrink
+ PReLU
+ Softsign
+ Tanhshrink
+ Softmin
+ Softmax
+ Softmax2d
+ LogSoftmax
+
+
+Convolutional Layers
+--------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Conv1d
+ Conv2d
+ Conv3d
+ Conv1D
+ Conv2D
+ Conv3D
+ ConvTranspose1d
+ ConvTranspose2d
+ ConvTranspose3d
+
+
+Dense Connection Layers
+-----------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Dense
+ Linear
+ Identity
+ AllToAll
+ OneToOne
+ MaskedLinear
+ CSRLinear
+ EventCSRLinear
+ JitFPHomoLinear
+ JitFPUniformLinear
+ JitFPNormalLinear
+ EventJitFPHomoLinear
+ EventJitFPNormalLinear
+ EventJitFPUniformLinear
+
+
+Normalization Layers
+--------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ BatchNorm1d
+ BatchNorm2d
+ BatchNorm3d
+ BatchNorm1D
+ BatchNorm2D
+ BatchNorm3D
+ LayerNorm
+ GroupNorm
+ InstanceNorm
+
+
+Pooling Layers
+--------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ MaxPool
+ MaxPool1d
+ MaxPool2d
+ MaxPool3d
+ MinPool
+ AvgPool
+ AvgPool1d
+ AvgPool2d
+ AvgPool3d
+ AdaptiveAvgPool1d
+ AdaptiveAvgPool2d
+ AdaptiveAvgPool3d
+ AdaptiveMaxPool1d
+ AdaptiveMaxPool2d
+ AdaptiveMaxPool3d
+
+
+Artificial Recurrent Layers
+---------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ NVAR
+ Reservoir
+ RNNCell
+ GRUCell
+ LSTMCell
+ Conv1dLSTMCell
+ Conv2dLSTMCell
+ Conv3dLSTMCell
+
+
+Interoperation with Flax
+------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ FromFlax
+ ToFlaxRNNCell
+ ToFlax
+
+
+Other Layers
+------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Layer
+ Dropout
+ Activation
+ Flatten
+ FunAsLayer
+
+
diff --git a/docs/apis/dyn.rst b/docs/apis/dyn.rst
new file mode 100644
index 000000000..0b8a3431e
--- /dev/null
+++ b/docs/apis/dyn.rst
@@ -0,0 +1,18 @@
+``brainpy.dyn`` module
+======================
+
+
+.. toctree::
+ :maxdepth: 1
+
+ brainpy.dyn.base.rst
+ brainpy.dyn.ions.rst
+ brainpy.dyn.channels.rst
+ brainpy.dyn.neurons.rst
+ brainpy.dyn.synapses.rst
+ brainpy.dyn.outs.rst
+ brainpy.dyn.rates.rst
+ brainpy.dyn.projections.rst
+ brainpy.dyn.plasticity.rst
+ brainpy.dyn.others.rst
+
diff --git a/docs/apis/encoding.rst b/docs/apis/encoding.rst
new file mode 100644
index 000000000..23736b1af
--- /dev/null
+++ b/docs/apis/encoding.rst
@@ -0,0 +1,16 @@
+``brainpy.encoding`` module
+===========================
+
+.. currentmodule:: brainpy.encoding
+.. automodule:: brainpy.encoding
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Encoder
+ LatencyEncoder
+ WeightedPhaseEncoder
+ PoissonEncoder
+ DiffEncoder
diff --git a/docs/apis/initialize.rst b/docs/apis/initialize.rst
new file mode 100644
index 000000000..fcce922c8
--- /dev/null
+++ b/docs/apis/initialize.rst
@@ -0,0 +1,68 @@
+``brainpy.initialize`` module
+=============================
+
+.. currentmodule:: brainpy.initialize
+.. automodule:: brainpy.initialize
+
+.. contents::
+ :local:
+ :depth: 1
+
+Basic Initialization Classes
+----------------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Initializer
+
+
+Regular Initializers
+--------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ZeroInit
+ Constant
+ OneInit
+ Identity
+
+
+Random Initializers
+-------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Normal
+ Uniform
+ VarianceScaling
+ KaimingUniform
+ KaimingNormal
+ XavierUniform
+ XavierNormal
+ LecunUniform
+ LecunNormal
+ Orthogonal
+ DeltaOrthogonal
+
+
+Decay Initializers
+------------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ GaussianDecay
+ DOGDecay
+
+
diff --git a/docs/apis/inputs.rst b/docs/apis/inputs.rst
new file mode 100644
index 000000000..e05372e8c
--- /dev/null
+++ b/docs/apis/inputs.rst
@@ -0,0 +1,17 @@
+``brainpy.inputs`` module
+=========================
+
+.. currentmodule:: brainpy.inputs
+.. automodule:: brainpy.inputs
+
+.. autosummary::
+ :toctree: generated/
+
+ section_input
+ constant_input
+ spike_input
+ ramp_input
+ wiener_process
+ ou_process
+ sinusoidal_input
+ square_input
diff --git a/docs/apis/integrators.rst b/docs/apis/integrators.rst
new file mode 100644
index 000000000..187b4e9a4
--- /dev/null
+++ b/docs/apis/integrators.rst
@@ -0,0 +1,205 @@
+``brainpy.integrators`` module
+==============================
+
+.. currentmodule:: brainpy.integrators
+.. automodule:: brainpy.integrators
+
+.. contents::
+ :local:
+ :depth: 2
+
+ODE integrators
+---------------
+
+.. currentmodule:: brainpy.integrators.ode
+.. automodule:: brainpy.integrators.ode
+
+Base ODE Integrator
+~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ODEIntegrator
+
+
+Generic ODE Functions
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ set_default_odeint
+ get_default_odeint
+ register_ode_integrator
+ get_supported_methods
+
+
+Explicit Runge-Kutta ODE Integrators
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ExplicitRKIntegrator
+ Euler
+ MidPoint
+ Heun2
+ Ralston2
+ RK2
+ RK3
+ Heun3
+ Ralston3
+ SSPRK3
+ RK4
+ Ralston4
+ RK4Rule38
+
+
+Adaptive Runge-Kutta ODE Integrators
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ AdaptiveRKIntegrator
+ RKF12
+ RKF45
+ DormandPrince
+ CashKarp
+ BogackiShampine
+ HeunEuler
+
+
+Exponential ODE Integrators
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ ExponentialEuler
+
+
+SDE integrators
+---------------
+
+.. currentmodule:: brainpy.integrators.sde
+.. automodule:: brainpy.integrators.sde
+
+Base SDE Integrator
+~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SDEIntegrator
+
+
+Generic SDE Functions
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ set_default_sdeint
+ get_default_sdeint
+ register_sde_integrator
+ get_supported_methods
+
+
+Normal SDE Integrators
+~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Euler
+ Heun
+ Milstein
+ MilsteinGradFree
+ ExponentialEuler
+
+
+SRK methods for scalar Wiener process
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ SRK1W1
+ SRK2W1
+ KlPl
+
+
+FDE integrators
+---------------
+
+.. currentmodule:: brainpy.integrators.fde
+.. automodule:: brainpy.integrators.fde
+
+Base FDE Integrator
+~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ FDEIntegrator
+
+
+Generic FDE Functions
+~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ set_default_fdeint
+ get_default_fdeint
+ register_fde_integrator
+ get_supported_methods
+
+
+Methods for Caputo Fractional Derivative
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ CaputoEuler
+ CaputoL1Schema
+
+
+Methods for Riemann-Liouville Fractional Derivative
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ GLShortMemory
+
+
diff --git a/docs/apis/losses.rst b/docs/apis/losses.rst
new file mode 100644
index 000000000..8f50c487f
--- /dev/null
+++ b/docs/apis/losses.rst
@@ -0,0 +1,57 @@
+``brainpy.losses`` module
+=========================
+
+.. currentmodule:: brainpy.losses
+.. automodule:: brainpy.losses
+
+.. contents::
+ :local:
+ :depth: 1
+
+Comparison
+----------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ cross_entropy_loss
+ cross_entropy_sparse
+ cross_entropy_sigmoid
+ nll_loss
+ l1_loss
+ l2_loss
+ huber_loss
+ mean_absolute_error
+ mean_squared_error
+ mean_squared_log_error
+ binary_logistic_loss
+ multiclass_logistic_loss
+ sigmoid_binary_cross_entropy
+ softmax_cross_entropy
+ log_cosh_loss
+ ctc_loss_with_forward_probs
+ ctc_loss
+ CrossEntropyLoss
+ NLLLoss
+ L1Loss
+ MAELoss
+ MSELoss
+
+
+Regularization
+--------------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ l2_norm
+ mean_absolute
+ mean_square
+ log_cosh
+ smooth_labels
+
+
diff --git a/docs/apis/math.rst b/docs/apis/math.rst
new file mode 100644
index 000000000..92e4f56fc
--- /dev/null
+++ b/docs/apis/math.rst
@@ -0,0 +1,480 @@
+``brainpy.math`` module
+=======================
+
+.. contents::
+ :local:
+ :depth: 1
+
+Objects and Variables
+---------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ BrainPyObject
+ FunAsObject
+ Partial
+ NodeList
+ NodeDict
+ node_dict
+ node_list
+ Variable
+ Parameter
+ TrainVar
+ VariableView
+ VarList
+ VarDict
+ var_list
+ var_dict
+
+
+Object-oriented Transformations
+-------------------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ grad
+ vector_grad
+ jacobian
+ jacrev
+ jacfwd
+ hessian
+ make_loop
+ make_while
+ make_cond
+ cond
+ ifelse
+ for_loop
+ while_loop
+ jit
+ cls_jit
+ to_object
+ function
+
+
+Environment Settings
+--------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ set
+ set_environment
+ set_float
+ get_float
+ set_int
+ get_int
+ set_bool
+ get_bool
+ set_complex
+ get_complex
+ set_dt
+ get_dt
+ set_mode
+ get_mode
+ enable_x64
+ disable_x64
+ set_platform
+ get_platform
+ set_host_device_count
+ clear_buffer_memory
+ enable_gpu_memory_preallocation
+ disable_gpu_memory_preallocation
+ ditype
+ dftype
+ environment
+ batching_environment
+ training_environment
+
+
+Array Interoperability
+----------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ as_device_array
+ as_jax
+ as_ndarray
+ as_numpy
+ as_variable
+
+
+Operators for Pre-Syn-Post Conversion
+-------------------------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ pre2post_sum
+ pre2post_prod
+ pre2post_max
+ pre2post_min
+ pre2post_mean
+ pre2post_event_sum
+ pre2post_csr_event_sum
+ pre2post_coo_event_sum
+ pre2syn
+ syn2post_sum
+ syn2post
+ syn2post_prod
+ syn2post_max
+ syn2post_min
+ syn2post_mean
+ syn2post_softmax
+
+
+Activation Functions
+--------------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ celu
+ elu
+ gelu
+ glu
+ prelu
+ silu
+ selu
+ relu
+ relu6
+ rrelu
+ hard_silu
+ leaky_relu
+ hard_tanh
+ hard_sigmoid
+ tanh_shrink
+ hard_swish
+ hard_shrink
+ soft_sign
+ soft_shrink
+ softmax
+ softmin
+ softplus
+ swish
+ mish
+ log_sigmoid
+ log_softmax
+ one_hot
+ normalize
+ sigmoid
+ identity
+ tanh
+
+
+Delay Variables
+---------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ TimeDelay
+ LengthDelay
+ NeuTimeDelay
+ NeuLenDelay
+ ROTATE_UPDATE
+ CONCAT_UPDATE
+
+
+Computing Modes
+---------------
+
+.. currentmodule:: brainpy.math
+.. automodule:: brainpy.math
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Mode
+ NonBatchingMode
+ BatchingMode
+ TrainingMode
+ nonbatching_mode
+ batching_mode
+ training_mode
+
+
+``brainpy.math.sparse`` module: Sparse Operators
+------------------------------------------------
+
+.. currentmodule:: brainpy.math.sparse
+.. automodule:: brainpy.math.sparse
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ csrmv
+ coomv
+ seg_matmul
+ csr_to_dense
+ csr_to_coo
+ coo_to_csr
+
+
+``brainpy.math.event`` module: Event-driven Operators
+-----------------------------------------------------
+
+.. currentmodule:: brainpy.math.event
+.. automodule:: brainpy.math.event
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ csrmv
+ info
+
+
+``brainpy.math.jitconn`` module: Just-In-Time Connectivity Operators
+--------------------------------------------------------------------
+
+.. currentmodule:: brainpy.math.jitconn
+.. automodule:: brainpy.math.jitconn
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ event_mv_prob_homo
+ event_mv_prob_uniform
+ event_mv_prob_normal
+ mv_prob_homo
+ mv_prob_uniform
+ mv_prob_normal
+
+
+``brainpy.math.surrogate`` module: Surrogate Gradient Functions
+---------------------------------------------------------------
+
+.. currentmodule:: brainpy.math.surrogate
+.. automodule:: brainpy.math.surrogate
+
+.. autosummary::
+ :toctree: generated/
+
+ Surrogate
+ Sigmoid
+ sigmoid
+ PiecewiseQuadratic
+ piecewise_quadratic
+ PiecewiseExp
+ piecewise_exp
+ SoftSign
+ soft_sign
+ Arctan
+ arctan
+ NonzeroSignLog
+ nonzero_sign_log
+ ERF
+ erf
+ PiecewiseLeakyRelu
+ piecewise_leaky_relu
+ SquarewaveFourierSeries
+ squarewave_fourier_series
+ S2NN
+ s2nn
+ QPseudoSpike
+ q_pseudo_spike
+ LeakyRelu
+ leaky_relu
+ LogTailedRelu
+ log_tailed_relu
+ ReluGrad
+ relu_grad
+ GaussianGrad
+ gaussian_grad
+ InvSquareGrad
+ inv_square_grad
+ MultiGaussianGrad
+ multi_gaussian_grad
+ SlayerGrad
+ slayer_grad
+ inv_square_grad2
+ relu_grad2
+
+
+
+``brainpy.math.random`` module: Random Number Generations
+---------------------------------------------------------
+
+.. currentmodule:: brainpy.math.random
+.. automodule:: brainpy.math.random
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ seed
+ split_key
+ split_keys
+ default_rng
+ rand
+ randint
+ random_integers
+ randn
+ random
+ random_sample
+ ranf
+ sample
+ choice
+ permutation
+ shuffle
+ beta
+ exponential
+ gamma
+ gumbel
+ laplace
+ logistic
+ normal
+ pareto
+ poisson
+ standard_cauchy
+ standard_exponential
+ standard_gamma
+ standard_normal
+ standard_t
+ uniform
+ truncated_normal
+ bernoulli
+ lognormal
+ binomial
+ chisquare
+ dirichlet
+ geometric
+ f
+ hypergeometric
+ logseries
+ multinomial
+ multivariate_normal
+ negative_binomial
+ noncentral_chisquare
+ noncentral_f
+ power
+ rayleigh
+ triangular
+ vonmises
+ wald
+ weibull
+ weibull_min
+ zipf
+ maxwell
+ t
+ orthogonal
+ loggamma
+ categorical
+ rand_like
+ randint_like
+ randn_like
+ RandomState
+ Generator
+ DEFAULT
+
+
+``brainpy.math.linalg`` module: Linear algebra
+----------------------------------------------
+
+.. currentmodule:: brainpy.math.linalg
+.. automodule:: brainpy.math.linalg
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ cholesky
+ cond
+ det
+ eig
+ eigh
+ eigvals
+ eigvalsh
+ inv
+ svd
+ lstsq
+ matrix_power
+ matrix_rank
+ norm
+ pinv
+ qr
+ solve
+ slogdet
+ tensorinv
+ tensorsolve
+ multi_dot
+
+
+``brainpy.math.fft`` module: Discrete Fourier Transform
+-------------------------------------------------------
+
+.. currentmodule:: brainpy.math.fft
+.. automodule:: brainpy.math.fft
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ fft
+ fft2
+ fftfreq
+ fftn
+ fftshift
+ hfft
+ ifft
+ ifft2
+ ifftn
+ ifftshift
+ ihfft
+ irfft
+ irfft2
+ irfftn
+ rfft
+ rfft2
+ rfftfreq
+ rfftn
+
+
diff --git a/docs/apis/measure.rst b/docs/apis/measure.rst
new file mode 100644
index 000000000..931e53947
--- /dev/null
+++ b/docs/apis/measure.rst
@@ -0,0 +1,19 @@
+``brainpy.measure`` module
+==========================
+
+.. currentmodule:: brainpy.measure
+.. automodule:: brainpy.measure
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ cross_correlation
+ voltage_fluctuation
+ matrix_correlation
+ weighted_correlation
+ functional_connectivity
+ raster_plot
+ firing_rate
+ unitary_LFP
diff --git a/docs/apis/mixin.rst b/docs/apis/mixin.rst
new file mode 100644
index 000000000..d797bb37a
--- /dev/null
+++ b/docs/apis/mixin.rst
@@ -0,0 +1,22 @@
+``brainpy.mixin`` module
+========================
+
+.. currentmodule:: brainpy.mixin
+.. automodule:: brainpy.mixin
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+
+ MixIn
+ ReceiveInputProj
+ AlignPost
+ AutoDelaySupp
+ ParamDesc
+ ParamDescInit
+ BindCondData
+ Container
+ TreeNode
+ JointType
diff --git a/docs/apis/optim.rst b/docs/apis/optim.rst
new file mode 100644
index 000000000..49b09e594
--- /dev/null
+++ b/docs/apis/optim.rst
@@ -0,0 +1,63 @@
+``brainpy.optim`` module
+========================
+
+.. currentmodule:: brainpy.optim
+.. automodule:: brainpy.optim
+
+.. contents::
+ :local:
+ :depth: 1
+
+Optimizers
+----------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ Optimizer
+ SGD
+ Momentum
+ MomentumNesterov
+ Adagrad
+ Adadelta
+ RMSProp
+ Adam
+ LARS
+ Adan
+ AdamW
+
+
+Schedulers
+----------
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ make_schedule
+ partial
+ BrainPyObject
+ MathError
+ Scheduler
+ Constant
+ CallBasedScheduler
+ StepLR
+ MultiStepLR
+ CosineAnnealingLR
+ CosineAnnealingWarmRestarts
+ ExponentialLR
+ ExponentialDecayLR
+ ExponentialDecay
+ InverseTimeDecayLR
+ InverseTimeDecay
+ PolynomialDecayLR
+ PolynomialDecay
+ PiecewiseConstantLR
+ PiecewiseConstant
+ Sequence
+ Union
+
+
diff --git a/docs/apis/running.rst b/docs/apis/running.rst
new file mode 100644
index 000000000..aa46ca6d7
--- /dev/null
+++ b/docs/apis/running.rst
@@ -0,0 +1,17 @@
+``brainpy.running`` module
+==========================
+
+.. currentmodule:: brainpy.running
+.. automodule:: brainpy.running
+
+.. autosummary::
+ :toctree: generated/
+ :nosignatures:
+ :template: classtemplate.rst
+
+ jax_vectorize_map
+ jax_parallelize_map
+ process_pool
+ process_pool_lock
+ cpu_ordered_parallel
+ cpu_unordered_parallel
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index 3cccc347f..cbbb06df1 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -43,7 +43,7 @@ def _write_module(module_name, filename, header=None, template=False):
# write autosummary
fout.write('.. autosummary::\n')
if template:
- fout.write(' :template: class_template.rst\n')
+ fout.write(' :template: classtemplate.rst\n')
fout.write(' :toctree: generated/\n\n')
for m in functions:
fout.write(f' {m}\n')
@@ -77,7 +77,9 @@ def _write_submodules(module_name, filename, header=None, submodule_names=(), se
# write autosummary
fout.write('.. autosummary::\n')
- fout.write(' :toctree: generated/\n\n')
+ fout.write(' :toctree: generated/\n')
+ fout.write(' :nosignatures:\n')
+ fout.write(' :template: classtemplate.rst\n\n')
for m in functions:
fout.write(f' {m}\n')
for m in classes:
@@ -109,7 +111,9 @@ def _write_subsections(module_name,
fout.write(name + '\n')
fout.write('-' * len(name) + '\n\n')
fout.write('.. autosummary::\n')
- fout.write(' :toctree: generated/\n\n')
+ fout.write(' :toctree: generated/\n')
+ fout.write(' :nosignatures:\n')
+ fout.write(' :template: classtemplate.rst\n\n')
for m in values:
fout.write(f' {m}\n')
fout.write(f'\n\n')
@@ -140,7 +144,9 @@ def _write_subsections_v2(module_path,
fout.write(subheader + '\n')
fout.write('-' * len(subheader) + '\n\n')
fout.write('.. autosummary::\n')
- fout.write(' :toctree: generated/\n\n')
+ fout.write(' :toctree: generated/\n')
+ fout.write(' :nosignatures:\n')
+ fout.write(' :template: classtemplate.rst\n\n')
for m in functions:
fout.write(f' {m}\n')
for m in classes:
@@ -182,7 +188,9 @@ def _write_subsections_v3(module_path,
fout.write(subheader + '\n')
fout.write('~' * len(subheader) + '\n\n')
fout.write('.. autosummary::\n')
- fout.write(' :toctree: generated/\n\n')
+ fout.write(' :toctree: generated/\n')
+ fout.write(' :nosignatures:\n')
+ fout.write(' :template: classtemplate.rst\n\n')
for m in functions:
fout.write(f' {m}\n')
for m in classes:
@@ -220,7 +228,9 @@ def _write_subsections_v4(module_path,
fout.write('.. autosummary::\n')
- fout.write(' :toctree: generated/\n\n')
+ fout.write(' :toctree: generated/\n')
+ fout.write(' :nosignatures:\n')
+ fout.write(' :template: classtemplate.rst\n\n')
for m in functions:
fout.write(f' {m}\n')
for m in classes:
diff --git a/docs/conf.py b/docs/conf.py
index 8853c8b1f..19b1ab5bc 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -18,25 +18,26 @@
sys.path.insert(0, os.path.abspath('../'))
import brainpy
-from docs import auto_generater
os.makedirs('apis/auto/', exist_ok=True)
-auto_generater.generate_analysis_docs()
-auto_generater.generate_connect_docs()
-auto_generater.generate_encoding_docs()
-auto_generater.generate_initialize_docs()
-auto_generater.generate_inputs_docs()
-auto_generater.generate_dnn_docs()
-auto_generater.generate_dyn_docs()
-auto_generater.generate_losses_docs()
-auto_generater.generate_measure_docs()
-auto_generater.generate_optim_docs()
-auto_generater.generate_running_docs()
-auto_generater.generate_brainpy_docs()
-auto_generater.generate_integrators_doc()
-auto_generater.generate_math_docs()
-auto_generater.generate_mixin_docs()
+# from docs import auto_generater
+# auto_generater.generate_analysis_docs()
+# auto_generater.generate_connect_docs()
+# auto_generater.generate_encoding_docs()
+# auto_generater.generate_initialize_docs()
+# auto_generater.generate_inputs_docs()
+# auto_generater.generate_dnn_docs()
+# auto_generater.generate_dyn_docs()
+# auto_generater.generate_losses_docs()
+# auto_generater.generate_measure_docs()
+# auto_generater.generate_optim_docs()
+# auto_generater.generate_running_docs()
+# auto_generater.generate_brainpy_docs()
+# auto_generater.generate_integrators_doc()
+# auto_generater.generate_math_docs()
+# auto_generater.generate_mixin_docs()
+# sys.exit()
changelogs = [
('../changelog.rst', 'apis/auto/changelog.rst'),
diff --git a/docs/core_concept/brainpy_transform_concept-old.ipynb b/docs/core_concept/brainpy_transform_concept-old.ipynb
deleted file mode 100644
index c8b3a771b..000000000
--- a/docs/core_concept/brainpy_transform_concept-old.ipynb
+++ /dev/null
@@ -1,654 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "collapsed": true,
- "jupyter": {
- "outputs_hidden": true
- }
- },
- "source": [
- "# Concept 1: Object-oriented Transformation"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "@[Chaoming Wang](https://github.com/chaoming0625)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Most computation in BrainPy relies on [JAX](https://jax.readthedocs.io/en/latest/).\n",
- "JAX has provided wonderful transformations, including differentiation, vecterization, parallelization and just-in-time compilation, for Python programs. If you are not familiar with it, please see its [documentation](https://jax.readthedocs.io/en/latest/).\n",
- "\n",
- "However, JAX only supports functional programming, i.e., transformations for Python functions. This is not what we want. Brain Dynamics Modeling need object-oriented programming.\n",
- "\n",
- "To meet this requirement, BrainPy defines the interface for object-oriented (OO) transformations. These OO transformations can be easily performed for BrainPy objects.\n",
- "\n",
- "In this section, let's talk about the BrainPy concept of object-oriented transformations."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "import brainpy as bp\n",
- "import brainpy.math as bm\n",
- "\n",
- "# bm.set_platform('cpu')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "collapsed": false,
- "jupyter": {
- "outputs_hidden": false
- }
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'2.3.0'"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "bp.__version__"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Illustrating example: Training a network"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "To illustrate this concept, we need a demonstration example. Here, we choose the popular neural network training as the illustrating case.\n",
- "\n",
- "In this training case, we want to teach the neural network to correctly classify a random array as two labels (`True` or `False`). That is, we have the training data:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "num_in = 100\n",
- "num_sample = 256\n",
- "X = bm.random.rand(num_sample, num_in)\n",
- "Y = (bm.random.rand(num_sample) < 0.5).astype(float)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "We use a two-layer feedforward network:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Sequential(\n",
- " [0] Linear0\n",
- " [1] relu\n",
- " [2] Linear1\n",
- ")\n"
- ]
- }
- ],
- "source": [
- "class Linear(bp.BrainPyObject):\n",
- " def __init__(self, n_in, n_out):\n",
- " super().__init__()\n",
- " self.num_in = n_in\n",
- " self.num_out = n_out\n",
- " init = bp.init.XavierNormal()\n",
- " self.W = bm.Variable(init((n_in, n_out)))\n",
- " self.b = bm.Variable(bm.zeros((1, n_out)))\n",
- "\n",
- " def __call__(self, x):\n",
- " return x @ self.W + self.b\n",
- "\n",
- "\n",
- "net = bp.Sequential(Linear(num_in, 20),\n",
- " bm.relu,\n",
- " Linear(20, 2))\n",
- "print(net)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Here, we use a supervised learning training paradigm. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Train 400 epoch, loss = 0.6710\n",
- "Train 800 epoch, loss = 0.5992\n",
- "Train 1200 epoch, loss = 0.5332\n",
- "Train 1600 epoch, loss = 0.4720\n",
- "Train 2000 epoch, loss = 0.4189\n",
- "Train 2400 epoch, loss = 0.3736\n",
- "Train 2800 epoch, loss = 0.3335\n",
- "Train 3200 epoch, loss = 0.2972\n",
- "Train 3600 epoch, loss = 0.2644\n",
- "Train 4000 epoch, loss = 0.2346\n"
- ]
- }
- ],
- "source": [
- "rng = bm.random.RandomState(123)\n",
- "\n",
- "\n",
- "# Loss function\n",
- "@bm.to_object(child_objs=net, dyn_vars=rng)\n",
- "def loss():\n",
- " # shuffle the data\n",
- " key = rng.split_key()\n",
- " x_data = rng.permutation(X, key=key)\n",
- " y_data = rng.permutation(Y, key=key)\n",
- " # prediction\n",
- " predictions = net(dict(), x_data)\n",
- " # loss\n",
- " l = bp.losses.cross_entropy_loss(predictions, y_data)\n",
- " return l\n",
- "\n",
- "\n",
- "# Gradient function\n",
- "grad = bm.grad(loss, grad_vars=net.vars(), return_value=True)\n",
- "\n",
- "# Optimizer\n",
- "optimizer = bp.optim.SGD(lr=1e-2, train_vars=net.vars())\n",
- "\n",
- "\n",
- "# Training step\n",
- "@bm.to_object(child_objs=(grad, optimizer))\n",
- "def train(i):\n",
- " grads, l = grad()\n",
- " optimizer.update(grads)\n",
- " return l\n",
- "\n",
- "\n",
- "num_step = 400\n",
- "for i in range(0, 4000, num_step):\n",
- " # train 400 steps once\n",
- " ls = bm.for_loop(train, operands=bm.arange(i, i + num_step))\n",
- " print(f'Train {i + num_step} epoch, loss = {bm.mean(ls):.4f}')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the above example, we have seen classical elements in a neural network training, such as \n",
- "\n",
- "- `net`: neural network\n",
- "- `loss`: loss function\n",
- "- `grad`: gradient function\n",
- "- `optimizer`: parameter optimizer\n",
- "- `train`: training step"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In BrainPy, all these elements can be defined as class objects and can be used for performing OO transformations. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In essence, the concept of BrainPy object-oriented transformation has three components:\n",
- "\n",
- "- `BrainPyObject`: the base class for object-oriented programming\n",
- "- `Variable`: the varibles in the class object, whose values are ready to be changed/updated during transformation\n",
- "- `ObjectTransform`: the transformations for computation involving `BrainPyObject` and `Variable`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## ``BrainPyObject`` and its ``Variable``"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "``BrainPyObject`` is the base class for object-oriented programming in BrainPy. \n",
- "It can be viewed as a container which contains all needed [Variable](../tutorial_math/arrays_and_variables.ipynb) for our computation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![](./imgs/net_with_two_linear.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In the above example, ``Linear`` object has two ``Variable``: *W* and *b*. The ``net`` we defined is further composed of two ``Linear`` objects. We can expect that four variables can be retrieved from it."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "dict_keys(['Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "net.vars().keys()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "An important question is, **how to define `Variable` in a `BrainPyObject` so that we can retrieve all of them?**"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Actually, all Variable instance which can be accessed by `self.` attribue can be retrived from a `BrainPyObject` recursively. \n",
- "No matter how deep the composition of ``BrainPyObject``, once `BrainPyObject` instance and their `Variable` instances can be accessed by `self.` operation, all of them will be retrieved. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "class SuperLinear(bp.BrainPyObject):\n",
- " def __init__(self, ):\n",
- " super().__init__()\n",
- " self.l1 = Linear(10, 20)\n",
- " self.v1 = bm.Variable(3)\n",
- " \n",
- "sl = SuperLinear()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "dict_keys(['SuperLinear0.v1', 'Linear2.W', 'Linear2.b'])"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# retrieve Variable\n",
- "sl.vars().keys()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "dict_keys(['SuperLinear0', 'Linear2'])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "# retrieve BrainPyObject\n",
- "sl.nodes().keys()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "However, we cannot access the ``BrainPyObject`` or ``Variable`` which is in a Python container (like tuple, list, or dict). For this case, we can register our objects and variables through ``.register_implicit_vars()`` and ``.register_implicit_nodes()``:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [],
- "source": [
- "class SuperSuperLinear(bp.BrainPyObject):\n",
- " def __init__(self, register=False):\n",
- " super().__init__()\n",
- " self.ss = [SuperLinear(), SuperLinear()]\n",
- " self.vv = {'v_a': bm.Variable(3)}\n",
- " if register:\n",
- " self.register_implicit_nodes(self.ss)\n",
- " self.register_implicit_vars(self.vv)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "dict_keys([])\n",
- "dict_keys(['SuperSuperLinear0'])\n"
- ]
- }
- ],
- "source": [
- "# without register\n",
- "ssl = SuperSuperLinear(register=False)\n",
- "print(ssl.vars().keys())\n",
- "print(ssl.nodes().keys())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "dict_keys(['SuperSuperLinear1.v_a', 'SuperLinear3.v1', 'SuperLinear4.v1', 'Linear5.W', 'Linear5.b', 'Linear6.W', 'Linear6.b'])\n",
- "dict_keys(['SuperSuperLinear1', 'SuperLinear3', 'SuperLinear4', 'Linear5', 'Linear6'])\n"
- ]
- }
- ],
- "source": [
- "# with register\n",
- "ssl = SuperSuperLinear(register=True)\n",
- "print(ssl.vars().keys())\n",
- "print(ssl.nodes().keys())"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Transform a function to `BrainPyObject`"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![](./imgs/loss_with_net_and_rng.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Let's go back to our network training.\n",
- "After the definition of `net`, we further define a ``loss`` function whose computation involves the ``net`` object for neural network prediction and a ``rng`` Variable for data shuffling. \n",
- "\n",
- "This Python function is then transformed into a ``BrainPyObject`` instance by ``brainpy.math.to_object`` interface. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "FunAsObject(nodes=[Sequential0],\n",
- " num_of_vars=1)"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "loss"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "All `Variable` used in this instance can also be retrieved through:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "dict_keys(['loss0._var0', 'Linear0.W', 'Linear0.b', 'Linear1.W', 'Linear1.b'])"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "loss.vars().keys()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that, when using `to_object()`, we need to explicitly declare all `BrainPyObject` and `Variable` used in this Python function. \n",
- "Due to the recursive retrieval property of `BrainPyObject`, we only need to specify the latest composition object.\n",
- "\n",
- "In the above `loss` object, we do not need to specify two ``Linear`` object. Instead, we only need to give the top level object ``net`` into ``to_object()`` transform. \n",
- "\n",
- "Similarly, when we transform ``train`` function into a ``BrainPyObject``, we just need to point out the ``grad`` and ``opt`` we have used, rather than the previous *loss*, *net* or *rng*. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![](./imgs/train_with_grad_and_opt.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## BrainPy object-oriented transformations"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "BrainPy object-oriented transformations are designed to work on ``BrainPyObject``. \n",
- "These transformations include autograd ``brainpy.math.grad()`` and JIT ``brainpy.math.jit()``."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In our case, we used two OO transformations provided in BrainPy. \n",
- "\n",
- "First, ``grad`` object is defined with the ``loss`` function. Within it, we need to specify what variables we need to compute their gradients through `grad_vars`. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Note that, the OO transformation of any ``BrainPyObject`` results in another ``BrainPyObject`` object. Therefore, it can be recersively used as a component to form the larger scope of object-oriented programming and object-oriented transformation. "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "GradientTransform(target=loss0, \n",
- " num_of_grad_vars=4, \n",
- " num_of_dyn_vars=1)"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "grad"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![](./imgs/grad_with_loss.png)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Next, we train 400 steps once by using a ``for_loop`` transformation. Different from ``grad`` which return a `BrainPyObject` instance, `for_loop` direactly returns the loop results. "
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "![](./imgs/for-loop-train.png)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "BrainPy",
- "language": "python",
- "name": "brainpy"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.6"
- },
- "latex_envs": {
- "LaTeX_envs_menu_present": true,
- "autoclose": false,
- "autocomplete": true,
- "bibliofile": "biblio.bib",
- "cite_by": "apalike",
- "current_citInitial": 1,
- "eqLabelWithNumbers": true,
- "eqNumInitial": 1,
- "hotkeys": {
- "equation": "Ctrl-E",
- "itemize": "Ctrl-I"
- },
- "labels_anchors": false,
- "latex_user_defs": false,
- "report_style_numbering": false,
- "user_envs_cfg": false
- },
- "toc": {
- "base_numbering": 1,
- "nav_menu": {},
- "number_sections": false,
- "sideBar": true,
- "skip_h1_title": false,
- "title_cell": "Table of Contents",
- "title_sidebar": "Contents",
- "toc_cell": false,
- "toc_position": {
- "height": "calc(100% - 180px)",
- "left": "10px",
- "top": "150px",
- "width": "245.75px"
- },
- "toc_section_display": true,
- "toc_window_display": true
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/docs/core_concept/imgs/for-loop-train.png b/docs/core_concept/imgs/for-loop-train.png
deleted file mode 100644
index 5c380e5a7..000000000
Binary files a/docs/core_concept/imgs/for-loop-train.png and /dev/null differ
diff --git a/docs/core_concept/imgs/grad_with_loss.png b/docs/core_concept/imgs/grad_with_loss.png
deleted file mode 100644
index 64e7d6ab9..000000000
Binary files a/docs/core_concept/imgs/grad_with_loss.png and /dev/null differ
diff --git a/docs/core_concept/imgs/loss_with_net_and_rng.png b/docs/core_concept/imgs/loss_with_net_and_rng.png
deleted file mode 100644
index 94e4b2af5..000000000
Binary files a/docs/core_concept/imgs/loss_with_net_and_rng.png and /dev/null differ
diff --git a/docs/core_concept/imgs/train_with_grad_and_opt.png b/docs/core_concept/imgs/train_with_grad_and_opt.png
deleted file mode 100644
index e5ff0ca30..000000000
Binary files a/docs/core_concept/imgs/train_with_grad_and_opt.png and /dev/null differ
diff --git a/docs/index.rst b/docs/index.rst
index 1cf3db2f3..583a30e08 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -195,8 +195,8 @@ Learn more
core_concepts.rst
tutorials.rst
- advanced_tutorials.rst
toolboxes.rst
+ advanced_tutorials.rst
FAQ.rst
api.rst
diff --git a/docs/quickstart/installation.rst b/docs/quickstart/installation.rst
index 346acfaca..68baef1ad 100644
--- a/docs/quickstart/installation.rst
+++ b/docs/quickstart/installation.rst
@@ -78,8 +78,8 @@ BrainPy relies on `JAX`_. JAX is a high-performance JIT compiler which enables
users to run Python code on CPU, GPU, and TPU devices. Core functionalities of
BrainPy (>=2.0.0) have been migrated to the JAX backend.
-Linux & MacOS
-^^^^^^^^^^^^^
+Linux
+^^^^^
Currently, JAX supports **Linux** (Ubuntu 16.04 or later) and **macOS** (10.12 or
later) platforms. The provided binary releases of `jax` and `jaxlib` for Linux and macOS
@@ -93,14 +93,20 @@ If you want to install a CPU-only version of `jax` and `jaxlib`, you can run
.. code-block:: bash
- pip install --upgrade "jax[cpu]" -f https://storage.googleapis.com/jax-releases/jax_releases.html
+ pip install --upgrade "jax[cpu]"
If you want to install JAX with both CPU and NVidia GPU support, you must first install
`CUDA`_ and `CuDNN`_, if they have not already been installed. Next, run
.. code-block:: bash
- pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+ # CUDA 12 installation
+ # Note: wheels only available on linux.
+ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+
+ # CUDA 11 installation
+ # Note: wheels only available on linux.
+ pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Alternatively, you can download the preferred release ".whl" file for jaxlib
@@ -108,23 +114,54 @@ from the above release links, and install it via ``pip``:
.. code-block:: bash
- pip install xxx-0.3.14-xxx.whl
+ pip install xxx-0.4.15-xxx.whl
- pip install jax==0.3.14
+ pip install jax==0.4.15
.. note::
- Note that the versions of `jaxlib` and `jax` should be consistent.
+ Note that the versions of jaxlib and jax should be consistent.
+
+ For example, if you are using jaxlib==0.4.15, you should also install jax==0.4.15.
+
+
+MacOS
+^^^^^
+
+If you are using macOS Intel, we recommend you first install the Miniconda Intel installer:
+
+1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.pkg
+2. Then click the downloaded package and install it.
- For example, if you are using `jax==0.3.14`, you would better install `jax==0.3.14`.
+
+If you are using the latest M1 macOS version, you'd better install the Miniconda M1 installer:
+
+
+1. Download the package in the link https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.pkg
+2. Then click the downloaded package and install it.
+
+
+Finally, you can install `jax` and `jaxlib` as the same as the Linux platform.
+
+.. code-block:: bash
+
+ pip install --upgrade "jax[cpu]"
Windows
^^^^^^^
-For **Windows** users, `jax` and `jaxlib` can be installed from the community supports.
-Specifically, you can install `jax` and `jaxlib` through:
+For **Windows** users with Python >= 3.9, `jax` and `jaxlib` can be installed
+directly from the PyPi channel.
+
+.. code-block:: bash
+
+ pip install jax jaxlib
+
+
+For **Windows** users with Python <= 3.8, `jax` and `jaxlib` can be installed
+from the community supports. Specifically, you can install `jax` and `jaxlib` through:
.. code-block:: bash
@@ -137,14 +174,15 @@ If you are using GPU, you can install GPU-versioned wheels through:
pip install "jax[cuda111]" -f https://whls.blob.core.windows.net/unstable/index.html
Alternatively, you can manually install you favourite version of `jax` and `jaxlib` by
-downloading binary releases of JAX for Windows from https://whls.blob.core.windows.net/unstable/index.html .
+downloading binary releases of JAX for Windows from
+https://whls.blob.core.windows.net/unstable/index.html .
Then install it via ``pip``:
.. code-block:: bash
- pip install xxx-0.3.14-xxx.whl
+ pip install xxx-0.4.15-xxx.whl
- pip install jax==0.3.14
+ pip install jax==0.4.15
WSL
^^^
@@ -159,17 +197,60 @@ Dependency 3: brainpylib
------------------------
Many customized operators in BrainPy are implemented in ``brainpylib``.
-``brainpylib`` can also be installed through `pypi `_.
+``brainpylib`` can also be installed from PyPI according to your device.
+For Windows, Linux and MacOS users, ``brainpylib`` supports CPU operators.
+You can install CPU-version `brainpylib` by:
.. code-block:: bash
- pip install brainpylib
+ # CPU installation
+ pip install --upgrade brainpylib
-For windows, Linux and MacOS users, ``brainpylib`` supports CPU operators.
+For Nvidia GPU users, ``brainpylib`` only supports the Linux system and the WSL2 subsystem. You can install the CUDA-version by using:
+
+.. code-block:: bash
+
+ # CUDA 12 installation
+ pip install --upgrade brainpylib-cu12x
+
+.. code-block:: bash
+
+ # CUDA 11 installation
+ pip install --upgrade brainpylib-cu11x
+
+Running BrainPy with docker
+---------------------------
+
+If you want to use BrainPy in docker, you can use the following command to pull the docker image:
+
+.. code:: bash
+
+ docker pull brainpy/brainpy:latest
+
+You can then run the docker image by:
+
+.. code:: bash
+
+ docker run -it --platform linux/amd64 brainpy/brainpy:latest
+
+Please note that the BrainPy docker image is based on the `ubuntu22.04` image, so it only supports the CPU version of BrainPy.
+
+
+Running BrainPy online with binder
+----------------------------------
+
+Click on the following link to launch the Binder environment with the
+BrainPy repository:
+
+|image1|
+
+Wait for the Binder environment to build. This might take a few moments.
-For CUDA users, ``brainpylib`` only support GPU on Linux platform. You can install GPU version ``brainpylib``
-on Linux through ``pip install brainpylib`` too.
+Once the environment is ready, you'll be redirected to a Jupyter
+notebook interface within your web browser.
+.. |image1| image:: https://camo.githubusercontent.com/581c077bdbc6ca6899c86d0acc6145ae85e9d80e6f805a1071793dbe48917982/68747470733a2f2f6d7962696e6465722e6f72672f62616467655f6c6f676f2e737667
+ :target: https://mybinder.org/v2/gh/brainpy/BrainPy-binder/main
.. _NumPy: https://numpy.org/
diff --git a/docs/quickstart/simulation.ipynb b/docs/quickstart/simulation.ipynb
index b83f47dc7..32aa7dca3 100644
--- a/docs/quickstart/simulation.ipynb
+++ b/docs/quickstart/simulation.ipynb
@@ -28,16 +28,18 @@
},
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
"id": "c4fbe84d",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:21.299843Z",
- "end_time": "2023-04-15T13:35:23.181553Z"
+ "end_time": "2023-09-10T08:44:44.998356100Z",
+ "start_time": "2023-09-10T08:44:43.279558300Z"
}
},
"outputs": [],
"source": [
+ "import numpy as np\n",
+ "\n",
"import brainpy as bp\n",
"import brainpy.math as bm\n",
"\n",
@@ -46,20 +48,20 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"id": "d0b5bce6",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:23.181553Z",
- "end_time": "2023-04-15T13:35:23.197148Z"
+ "end_time": "2023-09-10T08:44:45.015026300Z",
+ "start_time": "2023-09-10T08:44:44.998356100Z"
}
},
"outputs": [
{
"data": {
- "text/plain": "'2.4.0'"
+ "text/plain": "'2.4.4.post3'"
},
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@@ -117,23 +119,23 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "69556409",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:23.197148Z",
- "end_time": "2023-04-15T13:35:23.612974Z"
+ "end_time": "2023-09-10T08:44:45.746060600Z",
+ "start_time": "2023-09-10T08:44:45.017640300Z"
}
},
"outputs": [],
"source": [
- "E = bp.neurons.LIF(3200, V_rest=-60., V_th=-50., V_reset=-60.,\n",
- " tau=20., tau_ref=5., method='exp_auto',\n",
- " V_initializer=bp.init.Normal(-60., 2.))\n",
+ "E = bp.dyn.LifRef(3200, V_rest=-60., V_th=-50., V_reset=-60.,\n",
+ " tau=20., tau_ref=5., method='exp_auto',\n",
+ " V_initializer=bp.init.Normal(-60., 2.))\n",
"\n",
- "I = bp.neurons.LIF(800, V_rest=-60., V_th=-50., V_reset=-60.,\n",
- " tau=20., tau_ref=5., method='exp_auto',\n",
- " V_initializer=bp.init.Normal(-60., 2.))"
+ "I = bp.dyn.LifRef(800, V_rest=-60., V_th=-50., V_reset=-60.,\n",
+ " tau=20., tau_ref=5., method='exp_auto',\n",
+ " V_initializer=bp.init.Normal(-60., 2.))"
]
},
{
@@ -146,70 +148,126 @@
},
{
"cell_type": "markdown",
- "id": "abe09b1b",
- "metadata": {},
"source": [
- "Then the synaptic connections between these two groups can be defined as follows:"
- ]
+ "Before we define the synaptic projections between different populations, let's create a synapse model with the Exponential dynamics and conductance-based synaptic currents. "
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "24b642e81690f06a"
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "8be1733f",
+ "execution_count": 5,
+ "outputs": [],
+ "source": [
+ "class Exponential(bp.Projection): \n",
+ " def __init__(self, pre, post, delay, prob, g_max, tau, E):\n",
+ " super().__init__()\n",
+ " self.pron = bp.dyn.ProjAlignPost2(\n",
+ " pre=pre,\n",
+ " delay=delay,\n",
+ " # Event-driven computation\n",
+ " comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max), \n",
+ " syn=bp.dyn.Expon(size=post.num, tau=tau),# Exponential synapse\n",
+ " out=bp.dyn.COBA(E=E), # COBA network\n",
+ " post=post\n",
+ " )"
+ ],
"metadata": {
+ "collapsed": false,
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:23.612974Z",
- "end_time": "2023-04-15T13:35:25.688031Z"
+ "end_time": "2023-09-10T08:44:45.761555100Z",
+ "start_time": "2023-09-10T08:44:45.746060600Z"
}
},
- "outputs": [],
- "source": [
- "E2E = bp.synapses.Exponential(E, E, bp.conn.FixedProb(prob=0.02), g_max=0.6,\n",
- " tau=5., output=bp.synouts.COBA(E=0.),\n",
- " method='exp_auto')\n",
- "\n",
- "E2I = bp.synapses.Exponential(E, I, bp.conn.FixedProb(prob=0.02), g_max=0.6,\n",
- " tau=5., output=bp.synouts.COBA(E=0.),\n",
- " method='exp_auto')\n",
- "\n",
- "I2E = bp.synapses.Exponential(I, E, bp.conn.FixedProb(prob=0.02), g_max=6.7,\n",
- " tau=10., output=bp.synouts.COBA(E=-80.),\n",
- " method='exp_auto')\n",
- "\n",
- "I2I = bp.synapses.Exponential(I, I, bp.conn.FixedProb(prob=0.02), g_max=6.7,\n",
- " tau=10., output=bp.synouts.COBA(E=-80.),\n",
- " method='exp_auto')"
- ]
+ "id": "45b6804ed82895a"
},
{
"cell_type": "markdown",
"id": "13b3c3a9",
"metadata": {},
"source": [
- "Here we use the Exponential synapse model (``bp.synapses.Exponential``) to simulate synaptic connections. Among the parameters of the model, the first two denotes the pre- and post-synaptic neuron groups, respectively. The third one refers to the connection types. In this example, we use ``bp.conn.FixedProb``, which connects the presynaptic neurons to postsynaptic neurons with a given probability (detailed information is available in [Synaptic Connection](../tutorial_toolbox/synaptic_connections.ipynb)). The following three parameters describes the dynamic properties of the synapse, and the last one is the numerical integration method as that in the LIF model."
+ "Here we use the Align post projection method (``bp.dyn.ProjAlignPost2``) to simulate synaptic connections. Among the parameters of the model, the first two denote the pre- and post-synaptic neuron groups, respectively. The third one refers to the connection types. In this example, we use ``bp.conn.FixedProb``, which connects the pre-synaptic neurons to postsynaptic neurons with a given probability (detailed information is available in [Synaptic Connection](../tutorial_toolbox/synaptic_connections.ipynb)). The following three parameters describe the dynamic properties of the synapse, and the last one is the numerical integration method as that in the LIF model."
]
},
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Then the synaptic connections between these two groups can be defined as follows:"
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "abe09b1b"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "outputs": [],
+ "source": [
+ "# projection from E to E\n",
+ "E2E = Exponential(E, E, 0., 0.02, 0.6, 5., 0.)\n",
+ "\n",
+ "# projection from E to I\n",
+ "E2I = Exponential(E, I, 0., 0.02, 0.6, 5., 0.)\n",
+ "\n",
+ "# projection from I to E\n",
+ "I2E = Exponential(I, E, 0., 0.02, 6.7, 10., -80.)\n",
+ "\n",
+ "# projection from I to I\n",
+ "I2I = Exponential(I, I, 0., 0.02, 6.7, 10., -80.)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-09-10T08:44:48.194090100Z",
+ "start_time": "2023-09-10T08:44:45.761555100Z"
+ }
+ },
+ "id": "8be1733f"
+ },
{
"cell_type": "markdown",
"id": "572fa775",
"metadata": {},
"source": [
- "After defining all the components, they can be combined to form a network:"
+ "Putting these together, we can get an E/I balanced network."
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"id": "f8a6c731",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:25.678171Z",
- "end_time": "2023-04-15T13:35:25.694111Z"
+ "end_time": "2023-09-10T08:44:48.203744400Z",
+ "start_time": "2023-09-10T08:44:48.192540100Z"
}
},
"outputs": [],
"source": [
- "net = bp.Network(E2E, E2I, I2E, I2I, E=E, I=I)"
+ "class EINet(bp.DynamicalSystem):\n",
+ " def __init__(self, ne=3200, ni=800):\n",
+ " super().__init__()\n",
+ " self.E = bp.dyn.LifRef(ne, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,\n",
+ " V_initializer=bp.init.Normal(-55., 2.))\n",
+ " self.I = bp.dyn.LifRef(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,\n",
+ " V_initializer=bp.init.Normal(-55., 2.))\n",
+ " self.E2E = Exponential(self.E, self.E, 0., 0.02, 0.6, 5., 0.)\n",
+ " self.E2I = Exponential(self.E, self.I, 0., 0.02, 0.6, 5., 0.)\n",
+ " self.I2E = Exponential(self.I, self.E, 0., 0.02, 6.7, 10., -80.)\n",
+ " self.I2I = Exponential(self.I, self.I, 0., 0.02, 6.7, 10., -80.)\n",
+ "\n",
+ " def update(self, inp=0.):\n",
+ " self.E2E()\n",
+ " self.E2I()\n",
+ " self.I2E()\n",
+ " self.I2I()\n",
+ " self.E(inp)\n",
+ " self.I(inp)\n",
+ " # monitor\n",
+ " return self.E.spike, self.I.spike"
]
},
{
@@ -217,9 +275,7 @@
"id": "0412deb5",
"metadata": {},
"source": [
- "In the definition, neurons and synapses are given to the network. The excitatory and inhibitory neuron groups (`E` and `I`) are passed with a name, for they will be specifically operated in the simulation (here they will be given with input currents).\n",
- "\n",
- "We have successfully constructed an E-I balanced network by using BrainPy's biult-in models. On the other hand, BrianPy also enables users to customize their own dynamic models such as neuron groups, synapses, and networks flexibly. In fact, ``brainpy.dyn.Network()`` is a simple example of customizing a network model. Please refer to [Dynamic Simulation](../tutorial_simulation/index.rst) for more information."
+ "We have successfully constructed an E-I balanced network by using BrainPy's built-in models. On the other hand, BrainPy also enables users to customize their own dynamic models such as neuron groups, synapses, and networks flexibly. In fact, ``brainpy.DynSysGroup()`` is a simple example of customizing a network model. Please refer to [Dynamic Simulation](../tutorial_simulation/index.rst) for more information."
]
},
{
@@ -227,7 +283,9 @@
"id": "e3bcad34",
"metadata": {},
"source": [
- "### Running a simulation"
+ "### Running a simulation\n",
+ "\n",
+ "After building a SNN, we can use it for dynamic simulation. BrainPy provides multiple ways to simulate brain dynamics models. "
]
},
{
@@ -235,25 +293,24 @@
"id": "43ec39f4",
"metadata": {},
"source": [
- "After building a SNN, we can use it for dynamic simulation. To run a simulation, we need to wrap the network model into a **runner** first. BrainPy provides ``DSRunner`` in ``brainpy.dyn``, which will be expanded in the [Runners](../tutorial_simulation/index.rst) tutorial. Users can initialize ``DSRunner`` as followed:"
+ "First, BrainPy provides ``DSRunner`` in ``brainpy``, which will be expanded in the [Runners](../tutorial_simulation/index.rst) tutorial. Users can initialize ``DSRunner`` as follows:"
]
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 8,
"id": "8e16cd97",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:25.694111Z",
- "end_time": "2023-04-15T13:35:25.709754Z"
+ "end_time": "2023-09-10T08:44:48.983996200Z",
+ "start_time": "2023-09-10T08:44:48.203744400Z"
}
},
"outputs": [],
"source": [
- "runner = bp.DSRunner(net,\n",
- " monitors=['E.spike', 'I.spike'],\n",
- " inputs=[('E.input', 20.), ('I.input', 20.)],\n",
- " dt=0.1)"
+ "net = EINet()\n",
+ "\n",
+ "runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'])"
]
},
{
@@ -261,21 +318,19 @@
"id": "11473917",
"metadata": {},
"source": [
- "To make dynamic simulation more applicable and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the ``spike`` variable in the ``E`` and ``I`` LIF model, which refers to the spking status of the neuron group, and give a constant input to both neuron groups. The time interval of numerical integration ``dt`` (with the default value of 0.1) can also be specified.\n",
- "\n",
- "More details of how to give inputs and monitors please refer to [Dynamic Simulation](../tutorial_simulation/index.rst).\n",
+ "To make dynamic simulation more applicable and powerful, users can monitor variable trajectories and give inputs to target neuron groups. Here we monitor the ``spike`` variable in the ``E`` and ``I`` LIF model, which refers to the spiking status of the neuron group. For more details on how to give inputs and monitors, please refer to [Dynamic Simulation](../tutorial_simulation/index.rst).\n",
"\n",
- "After creating the runner, we can run a simulation by calling the runner:"
+ "After creating the runner, we can run a simulation by calling the runner, where the calling function receives the simulation time (usually in milliseconds) as the input. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to [Just-In-Time Compilation](../tutorial_math/brainpy_transform_concept.ipynb) for more details."
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 9,
"id": "a2a602d2",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:25.709754Z",
- "end_time": "2023-04-15T13:35:26.742003Z"
+ "end_time": "2023-09-10T08:44:50.192018700Z",
+ "start_time": "2023-09-10T08:44:48.983996200Z"
}
},
"outputs": [
@@ -285,7 +340,7 @@
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
- "model_id": "33ac887e0d7347a8aa9078635f0687a4"
+ "model_id": "cb881757388046c7876601f41a5e6afb"
}
},
"metadata": {},
@@ -293,34 +348,95 @@
}
],
"source": [
- "runner.run(100)"
+ "Is = bm.ones(1000) * 20. # 100 ms\n",
+ "_ = runner.run(inputs=Is)"
]
},
+ {
+ "cell_type": "markdown",
+ "source": [
+ "The monitored spikes are stored in ``runner.mon``. "
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "acff9360881308ef"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "outputs": [],
+ "source": [
+ "E_sps = runner.mon['E.spike']\n",
+ "I_sps = runner.mon['I.spike']"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-09-10T08:44:50.207020900Z",
+ "start_time": "2023-09-10T08:44:50.192018700Z"
+ }
+ },
+ "id": "3cf93c4cf74a2205"
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Second, users can also use ``brainpy.math.for_loop`` for the efficient simulation of any BrainPy models. To do that, we need to define a running function which defines the one-step updating function of the model. "
+ ],
+ "metadata": {
+ "collapsed": false
+ },
+ "id": "19ec58dbf4c20634"
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "outputs": [],
+ "source": [
+ "net = EINet()\n",
+ "\n",
+ "def run_fun(i):\n",
+ " # i: the running index\n",
+ " # 20.: the input\n",
+ " return net.step_run(i, 20.)\n",
+ "\n",
+ "indices = np.arange(int(100. / bm.get_dt())) # 100. ms\n",
+ "E_sps, I_sps = bm.for_loop(run_fun, indices)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "ExecuteTime": {
+ "end_time": "2023-09-10T08:44:51.621343100Z",
+ "start_time": "2023-09-10T08:44:50.209021100Z"
+ }
+ },
+ "id": "85c630f3902ce1b7"
+ },
{
"cell_type": "markdown",
"id": "8452dec3",
"metadata": {},
"source": [
- "where the calling function receives the simulation time (usually in milliseconds) as the input. BrainPy achieves an extraordinary simulation speed with the assistance of just-in-time (JIT) compilation. Please refer to [Just-In-Time Compilation](../tutorial_math/brainpy_transform_concept.ipynb) for more details.\n",
"\n",
"The simulation results are stored as NumPy arrays in the monitors, and can be visualized easily:"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 12,
"id": "f3aab08c",
"metadata": {
"ExecuteTime": {
- "start_time": "2023-04-15T13:35:26.725106Z",
- "end_time": "2023-04-15T13:35:27.147108Z"
+ "end_time": "2023-09-10T08:44:52.164740Z",
+ "start_time": "2023-09-10T08:44:51.605619800Z"
}
},
"outputs": [
{
"data": {
"text/plain": "