Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

brainpy.math.defjvp and brainpy.math.XLACustomOp.defjvp #554

Merged
merged 14 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
</p>


BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Numba](https://github.com/numba/numba), and other JIT compilers). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.
BrainPy is a flexible, efficient, and extensible framework for computational neuroscience and brain-inspired computation based on the Just-In-Time (JIT) compilation (built on top of [JAX](https://github.com/google/jax), [Taichi](https://github.com/taichi-dev/taichi), [Numba](https://github.com/numba/numba), and others). It provides an integrative ecosystem for brain dynamics programming, including brain dynamics **building**, **simulation**, **training**, **analysis**, etc.

- **Website (documentation and APIs)**: https://brainpy.readthedocs.io/en/latest
- **Source**: https://github.com/brainpy/BrainPy
Expand Down Expand Up @@ -77,6 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
- [《神经计算建模实战》 (Neural Modeling in Action)](https://github.com/c-xy17/NeuralModeling)
- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)


Expand Down
4 changes: 2 additions & 2 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.6"
__version__ = "2.4.6.post2"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down Expand Up @@ -75,7 +75,7 @@
)
NeuGroup = NeuGroupNS = dyn.NeuDyn

# shared parameters
# common tools
from brainpy._src.context import (share as share)
from brainpy._src.helpers import (reset_state as reset_state,
save_state as save_state,
Expand Down
38 changes: 19 additions & 19 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,41 @@
import os
import sys
from jax.lib import xla_client


__all__ = [
'import_taichi',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]


_minimal_brainpylib_version = '0.1.10'
_minimal_taichi_version = (1, 7, 0)

taichi = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
'> pip install taichi==1.7.0 -U')
os.environ["TI_LOG_LEVEL"] = "error"


def import_taichi():
global taichi
if taichi is None:
try:
import taichi as taichi # noqa
except ModuleNotFoundError:
raise ModuleNotFoundError(
'Taichi is needed. Please install taichi through:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)

if taichi.__version__ < _minimal_taichi_version:
raise RuntimeError(
f'We need taichi>={_minimal_taichi_version}. '
f'Currently you can install taichi>={_minimal_taichi_version} through taichi-nightly:\n\n'
'> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly'
)
with open(os.devnull, 'w') as devnull:
old_stdout = sys.stdout
sys.stdout = devnull
try:
import taichi as taichi # noqa
except ModuleNotFoundError:
raise ModuleNotFoundError(taichi_install_info)
finally:
sys.stdout = old_stdout

if taichi.__version__ != _minimal_taichi_version:
raise RuntimeError(taichi_install_info)
return taichi


Expand Down Expand Up @@ -82,6 +85,3 @@ def import_brainpylib_gpu_ops():
'See https://brainpy.readthedocs.io for installation instructions.')

return brainpylib_gpu_ops



198 changes: 194 additions & 4 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ class STDP_Song2000(Projection):

\begin{aligned}
\frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
\frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s}+A_1\delta(t-t_{sp}), \\
\frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t}+A_2\delta(t-t_{sp}), \\
\frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
\frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
\end{aligned}

where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
Expand All @@ -64,8 +64,8 @@ class STDP_Song2000(Projection):
class STDPNet(bp.DynamicalSystem):
def __init__(self, num_pre, num_post):
super().__init__()
self.pre = bp.dyn.LifRef(num_pre, name='neu1')
self.post = bp.dyn.LifRef(num_post, name='neu2')
self.pre = bp.dyn.LifRef(num_pre)
self.post = bp.dyn.LifRef(num_post)
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
Expand Down Expand Up @@ -219,3 +219,193 @@ def update(self):
return current


# class PairedSTDP(Projection):
# r"""Paired spike-time-dependent plasticity model.
#
# This model filters the synaptic currents according to the variables: :math:`w`.
#
# .. math::
#
# I_{syn}^+(t) = I_{syn}^-(t) * w
#
# where :math:`I_{syn}^-(t)` and :math:`I_{syn}^+(t)` are the synaptic currents before
# and after STDP filtering, :math:`w` measures synaptic efficacy because each time a presynaptic neuron emits a pulse,
# the conductance of the synapse will increase w.
#
# The dynamics of :math:`w` is governed by the following equation:
#
# .. math::
#
# \begin{aligned}
# \frac{dw}{dt} & = & -A_{post}\delta(t-t_{sp}) + A_{pre}\delta(t-t_{sp}), \\
# \frac{dA_{pre}}{dt} & = & -\frac{A_{pre}}{\tau_s} + A_1\delta(t-t_{sp}), \\
# \frac{dA_{post}}{dt} & = & -\frac{A_{post}}{\tau_t} + A_2\delta(t-t_{sp}), \\
# \end{aligned}
#
# where :math:`t_{sp}` denotes the spike time and :math:`A_1` is the increment
# of :math:`A_{pre}`, :math:`A_2` is the increment of :math:`A_{post}` produced by a spike.
#
# Here is an example of the usage of this class::
#
# import brainpy as bp
# import brainpy.math as bm
#
# class STDPNet(bp.DynamicalSystem):
# def __init__(self, num_pre, num_post):
# super().__init__()
# self.pre = bp.dyn.LifRef(num_pre)
# self.post = bp.dyn.LifRef(num_post)
# self.syn = bp.dyn.STDP_Song2000(
# pre=self.pre,
# delay=1.,
# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
# weight=bp.init.Uniform(max_val=0.1)),
# syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
# out=bp.dyn.COBA.desc(E=0.),
# post=self.post,
# tau_s=16.8,
# tau_t=33.7,
# A1=0.96,
# A2=0.53,
# )
#
# def update(self, I_pre, I_post):
# self.syn()
# self.pre(I_pre)
# self.post(I_post)
# conductance = self.syn.refs['syn'].g
# Apre = self.syn.refs['pre_trace'].g
# Apost = self.syn.refs['post_trace'].g
# current = self.post.sum_inputs(self.post.V)
# return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
#
# duration = 300.
# I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
# [5, 15, 15, 15, 15, 15, 100, 15, 15, 15, 15, 15, duration - 255])
# I_post = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
# [10, 15, 15, 15, 15, 15, 90, 15, 15, 15, 15, 15, duration - 250])
#
# net = STDPNet(1, 1)
# def run(i, I_pre, I_post):
# pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
# return pre_spike, post_spike, g, Apre, Apost, current, W
#
# indices = bm.arange(0, duration, bm.dt)
# pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])
#
# Args:
# tau_s: float. The time constant of :math:`A_{pre}`.
# tau_t: float. The time constant of :math:`A_{post}`.
# A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
# A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
# W_max: float. The maximum weight.
# W_min: float. The minimum weight.
# pre: DynamicalSystem. The pre-synaptic neuron group.
# delay: int, float. The pre spike delay length. (ms)
# syn: DynamicalSystem. The synapse model.
# comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers.
# out: DynamicalSystem. The synaptic current output models.
# post: DynamicalSystem. The post-synaptic neuron group.
# out_label: str. The output label.
# name: str. The model name.
# """
#
# def __init__(
# self,
# pre: JointType[DynamicalSystem, SupportAutoDelay],
# delay: Union[None, int, float],
# syn: ParamDescriber[DynamicalSystem],
# comm: JointType[DynamicalSystem, SupportSTDP],
# out: ParamDescriber[JointType[DynamicalSystem, BindCondData]],
# post: DynamicalSystem,
# # synapse parameters
# tau_s: float = 16.8,
# tau_t: float = 33.7,
# lambda_: float = 0.96,
# alpha: float = 0.53,
# mu: float = 0.53,
# W_max: Optional[float] = None,
# W_min: Optional[float] = None,
# # others
# out_label: Optional[str] = None,
# name: Optional[str] = None,
# mode: Optional[bm.Mode] = None,
# ):
# super().__init__(name=name, mode=mode)
#
# # synaptic models
# check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
# check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
# check.is_instance(syn, ParamDescriber[DynamicalSystem])
# check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
# check.is_instance(post, DynamicalSystem)
# self.pre_num = pre.num
# self.post_num = post.num
# self.comm = comm
# self._is_align_post = issubclass(syn.cls, AlignPost)
#
# # delay initialization
# delay_cls = register_delay_by_return(pre)
# delay_cls.register_entry(self.name, delay)
#
# # synapse and output initialization
# if self._is_align_post:
# syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post,
# proj_name=self.name)
# else:
# syn_cls = align_pre2_add_bef_update(syn, delay, delay_cls, self.name + '-pre')
# out_cls = out()
# add_inp_fun(out_label, self.name, out_cls, post)
#
# # references
# self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
# self.refs['delay'] = delay_cls
# self.refs['syn'] = syn_cls # invisible to ``self.node()``
# self.refs['out'] = out_cls # invisible to ``self.node()``
# self.refs['comm'] = comm
#
# # tracing pre-synaptic spikes using Exponential model
# self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))
#
# # tracing post-synaptic spikes using Exponential model
# self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))
#
# # synapse parameters
# self.W_max = W_max
# self.W_min = W_min
# self.tau_s = tau_s
# self.tau_t = tau_t
# self.A1 = A1
# self.A2 = A2
#
# def update(self):
# # pre-synaptic spikes
# pre_spike = self.refs['delay'].at(self.name) # spike
# # pre-synaptic variables
# if self._is_align_post:
# # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance
# x = pre_spike
# else:
# # For AlignPre, we need the "pre synapse variable @ comm matrix" for computing post conductance
# x = _get_return(self.refs['syn'].return_info()) # pre-synaptic variable
#
# # post spikes
# if not hasattr(self.refs['post'], 'spike'):
# raise AttributeError(f'{self} needs a "spike" variable for the post-synaptic neuron group.')
# post_spike = self.refs['post'].spike
#
# # weight updates
# Apost = self.refs['post_trace'].g
# self.comm.stdp_update(on_pre={"spike": pre_spike, "trace": -Apost * self.A2}, w_min=self.W_min, w_max=self.W_max)
# Apre = self.refs['pre_trace'].g
# self.comm.stdp_update(on_post={"spike": post_spike, "trace": Apre * self.A1}, w_min=self.W_min, w_max=self.W_max)
#
# # synaptic currents
# current = self.comm(x)
# if self._is_align_post:
# self.refs['syn'].add_current(current) # synapse post current
# else:
# self.refs['out'].bind_cond(current) # align pre
# return current


Loading