diff --git a/README.md b/README.md index fa553633f..716dbd900 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to - **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming. - **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation. - **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling. -- [第一届神经计算建模与编程培训班 (BrainPy First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) +- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course) ## Citing diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 2301bab7a..314ffb19c 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -3,6 +3,7 @@ from typing import Dict, Optional, Union, Callable +import numba import jax import numpy as np import jax.numpy as jnp @@ -227,6 +228,45 @@ def update(self, x): return x +def event_mm(pre_spike, post_inc, weight, w_min, w_max): + return weight + + +@numba.njit +def event_mm_imp(outs, ins): + pre_spike, post_inc, weight, w_min, w_max = ins + w_min = w_min[()] + w_max = w_max[()] + outs = outs + outs.fill(weight) + for i in range(pre_spike.shape[0]): + if pre_spike[i]: + outs[i] = np.clip(outs[i] + post_inc, w_min, w_max) + + +event_left_mm = bm.CustomOpByNumba(event_mm, event_mm_imp, multiple_results=False) + + +def event_mm2(post_spike, pre_inc, weight, w_min, w_max): + return weight + + +@numba.njit +def event_mm_imp2(outs, ins): + post_spike, pre_inc, weight, w_min, w_max = ins + w_min = w_min[()] + w_max = w_max[()] + outs = outs + outs.fill(weight) + for j in range(post_spike.shape[0]): + if post_spike[j]: + outs[:, j] = np.clip(outs[:, j] + pre_inc, w_min, w_max) + + +event_right_mm = bm.CustomOpByNumba(event_mm2, event_mm_imp2, multiple_results=False) + + + class AllToAll(Layer, SupportSTDP): """Synaptic matrix multiplication with All2All connections. @@ -289,20 +329,15 @@ def update(self, pre_val): post_val = pre_val @ self.weight return post_val - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - if self.weight.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.weight.shape}.') + def stdp_update_on_pre(self, pre_spike, trace, w_min=None, w_max=None): if not isinstance(self.weight, bm.Variable): self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += dW - if constraints is not None: - self.weight.value = constraints(self.weight) + self.weight.value = event_left_mm(pre_spike, trace, self.weight, w_min, w_max) + def stdp_update_on_post(self, post_spike, trace, w_min=None, w_max=None): + if not isinstance(self.weight, bm.Variable): + self.tracing_variable('weight', self.weight, self.weight.shape) + self.weight.value = event_right_mm(post_spike, trace, self.weight, w_min, w_max) class OneToOne(Layer, SupportSTDP): @@ -338,21 +373,6 @@ def __init__( def update(self, pre_val): return pre_val * self.weight - def update_STDP(self, dW, constraints=None): - if isinstance(self.weight, float): - raise ValueError(f'Cannot update the weight of a constant node.') - if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)): - raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}') - dW = dW.sum(axis=0) - if self.weight.shape != dW.shape: - raise ValueError(f'The shape of delta_weight {dW.shape} ' - f'should be the same as the shape of weight {self.weight.shape}.') - if not isinstance(self.weight, bm.Variable): - self.tracing_variable('weight', self.weight, self.weight.shape) - self.weight += dW - if constraints is not None: - self.weight.value = constraints(self.weight) - class MaskedLinear(Layer, SupportSTDP): r"""Synaptic matrix multiplication with masked dense computation. diff --git a/brainpy/_src/dyn/others/input.py b/brainpy/_src/dyn/others/input.py index 92a2390b4..60632dc9f 100644 --- a/brainpy/_src/dyn/others/input.py +++ b/brainpy/_src/dyn/others/input.py @@ -228,9 +228,5 @@ def update(self): def return_info(self): return self.spike - def reset_state(self, batch_size=None, **kwargs): - self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type), - self.varshape, - batch_size, - axis_names=self.sharding, - batch_axis_name=bm.sharding.BATCH_AXIS) + def reset_state(self, batch_or_mode=None, **kwargs): + self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode) diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 5894a1452..c51332e44 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -4,7 +4,6 @@ from brainpy._src.delay import register_delay_by_return from brainpy._src.dyn.synapses.abstract_models import Expon from brainpy._src.dynsys import DynamicalSystem, Projection -from brainpy._src.initialize import parameter from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP) from brainpy.types import ArrayType @@ -111,7 +110,8 @@ def run(i, I_pre, I_post): A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value. A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value. W_max: float. The maximum weight. - pre: DynamicalSystem. The pre-synaptic neuron group. + W_min: float. The minimum weight. + pre: DynamicalSystem. The pre-synaptic neuron group. delay: int, float. The pre spike delay length. (ms) syn: DynamicalSystem. The synapse model. comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers. @@ -135,6 +135,7 @@ def __init__( A1: Union[float, ArrayType, Callable] = 0.96, A2: Union[float, ArrayType, Callable] = 0.53, W_max: Optional[float] = None, + W_min: Optional[float] = None, # others out_label: Optional[str] = None, name: Optional[str] = None, @@ -144,21 +145,21 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay]) - check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP]) + check.is_instance(syn, ParamDescriber[DynamicalSystem]) check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]]) check.is_instance(post, DynamicalSystem) self.pre_num = pre.num self.post_num = post.num self.comm = comm - self.syn = syn + self._is_align_post = issubclass(syn.cls, AlignPost) # delay initialization delay_cls = register_delay_by_return(pre) delay_cls.register_entry(self.name, delay) # synapse and output initialization - if issubclass(syn.cls, AlignPost): + if self._is_align_post: syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post, proj_name=self.name) else: @@ -171,24 +172,27 @@ def __init__( self.refs['delay'] = delay_cls self.refs['syn'] = syn_cls # invisible to ``self.node()`` self.refs['out'] = out_cls # invisible to ``self.node()`` + self.refs['comm'] = comm # tracing pre-synaptic spikes using Exponential model self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s)) + # tracing post-synaptic spikes using Exponential model self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t)) # synapse parameters self.W_max = W_max - self.tau_s = parameter(tau_s, sizes=self.pre_num) - self.tau_t = parameter(tau_t, sizes=self.post_num) - self.A1 = parameter(A1, sizes=self.pre_num) - self.A2 = parameter(A2, sizes=self.post_num) + self.W_min = W_min + self.tau_s = tau_s + self.tau_t = tau_t + self.A1 = A1 + self.A2 = A2 def update(self): # pre-synaptic spikes pre_spike = self.refs['delay'].at(self.name) # spike # pre-synaptic variables - if issubclass(self.syn.cls, AlignPost): + if self._is_align_post: # For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance x = pre_spike else: @@ -201,19 +205,17 @@ def update(self): post_spike = self.refs['post'].spike # weight updates - Apre = self.refs['pre_trace'].g Apost = self.refs['post_trace'].g - delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike) - self.comm.update_STDP(delta_w, constraints=self._weight_clip) + self.comm.stdp_update_on_pre(pre_spike, -Apost * self.A2, self.W_min, self.W_max) + Apre = self.refs['pre_trace'].g + self.comm.stdp_update_on_post(post_spike, Apre * self.A1, self.W_min, self.W_max) - # currents + # synaptic currents current = self.comm(x) - if issubclass(self.syn.cls, AlignPost): + if self._is_align_post: self.refs['syn'].add_current(current) # synapse post current else: self.refs['out'].bind_cond(current) # align pre return current - def _weight_clip(self, w): - return w if self.W_max is None else bm.minimum(w, self.W_max) diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index e33644f26..001afc02e 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -21,8 +21,8 @@ def __init__(self, num_pre, num_post): pre=self.pre, delay=1., # comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num), - # weight=bp.init.Uniform(-0.1, 0.1)), - comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(-0.1, 0.1)), + # weight=bp.init.Uniform(0., 0.1)), + comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)), syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.), out=bp.dyn.COBA.desc(E=0.), post=self.post, @@ -30,6 +30,8 @@ def __init__(self, num_pre, num_post): tau_t=33.7, A1=0.96, A2=0.53, + W_min=0., + W_max=1. ) def update(self, I_pre, I_post): diff --git a/brainpy/_src/dyn/rates/tests/test_reservoir.py b/brainpy/_src/dyn/rates/tests/test_reservoir.py index 371c7aa89..34d00c909 100644 --- a/brainpy/_src/dyn/rates/tests/test_reservoir.py +++ b/brainpy/_src/dyn/rates/tests/test_reservoir.py @@ -15,7 +15,7 @@ class Test_Reservoir(parameterized.TestCase): def test_Reservoir(self, mode): bm.random.seed() input = bm.random.randn(10, 3) - layer = bp.syn.Reservoir(input_shape=3, + layer = bp.dyn.Reservoir(input_shape=3, num_out=5, mode=mode) if mode in [bm.NonBatchingMode()]: diff --git a/brainpy/_src/math/op_register/tests/test_ei_net.py b/brainpy/_src/math/op_register/tests/test_ei_net.py deleted file mode 100644 index 28d106cb2..000000000 --- a/brainpy/_src/math/op_register/tests/test_ei_net.py +++ /dev/null @@ -1,77 +0,0 @@ -import brainpy.math as bm -import brainpy as bp -from jax.core import ShapedArray - - -def abs_eval(events, indices, indptr, *, weight, post_num): - return [ShapedArray((post_num,), bm.float32), ] - - -def con_compute(outs, ins): - post_val, = outs - post_val.fill(0) - events, indices, indptr, weight, _ = ins - weight = weight[()] - for i in range(events.size): - if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - post_val[index] += weight - - -event_sum = bm.XLACustomOp(eval_shape=abs_eval, cpu_func=con_compute) - - -class ExponentialV2(bp.synapses.TwoEndConn): - """Exponential synapse model using customized operator written in C++.""" - - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0.): - super(ExponentialV2, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method='exp_auto') - - def update(self): - self.g.value = self.integral(self.g, bp.share['t']) - self.g += event_sum(self.pre.spike, - self.pre2post[0], - self.pre2post[1], - weight=self.g_max, - post_num=self.post.num)[0] - self.post.input += self.g * (self.E - self.post.V) - - -class EINet(bp.DynSysGroup): - def __init__(self, scale): - super().__init__() - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., - V_initializer=bp.init.Normal(-55., 2.), method='exp_auto') - self.E = bp.neurons.LIF(int(3200 * scale), **pars) - self.I = bp.neurons.LIF(int(800 * scale), **pars) - - # synapses - self.E2E = ExponentialV2(self.E, self.E, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - self.E2I = ExponentialV2(self.E, self.I, bp.conn.FixedProb(prob=0.02), E=0., g_max=0.6 / scale, tau=5.) - self.I2E = ExponentialV2(self.I, self.E, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - self.I2I = ExponentialV2(self.I, self.I, bp.conn.FixedProb(prob=0.02), E=-80., g_max=6.7 / scale, tau=10.) - - -def test1(): - bm.set_platform('cpu') - net2 = EINet(scale=0.1) - runner = bp.DSRunner(net2, inputs=[('E.input', 20.), ('I.input', 20.)]) - r = runner.predict(100., eval_time=True) - bm.clear_buffer_memory() diff --git a/brainpy/_src/math/tests/test_op_register.py b/brainpy/_src/math/tests/test_op_register.py deleted file mode 100644 index 6917202ad..000000000 --- a/brainpy/_src/math/tests/test_op_register.py +++ /dev/null @@ -1,141 +0,0 @@ -# -*- coding: utf-8 -*- - -import unittest - -import jax -import matplotlib.pyplot as plt - -import brainpy as bp -import brainpy.math as bm - - -bm.random.seed() -bm.set_platform('cpu') - - -def abs_eval(events, indices, indptr, post_val, values): - return [post_val] - - -def event_sum_op(outs, ins): - events, indices, indptr, post, values = ins - v = values[()] - outs, = outs - outs.fill(0) - for i in range(len(events)): - if events[i]: - for j in range(indptr[i], indptr[i + 1]): - index = indices[j] - outs[index] += v - - -event_sum2 = bm.XLACustomOp(name='event_sum2', cpu_func=event_sum_op, eval_shape=abs_eval) - - -class ExponentialSyn(bp.TwoEndConn): - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., - method='exp_auto'): - super(ExponentialSyn, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) - - def update(self, tdi): - self.g.value = self.integral(self.g, tdi['t'], dt=tdi['dt']) - self.g += bm.pre2post_event_sum(self.pre.spike, self.pre2post, self.post.num, self.g_max) - self.post.input += self.g * (self.E - self.post.V) - - -class ExponentialSyn3(bp.TwoEndConn): - def __init__(self, pre, post, conn, g_max=1., delay=0., tau=8.0, E=0., - method='exp_auto'): - super(ExponentialSyn3, self).__init__(pre=pre, post=post, conn=conn) - self.check_pre_attrs('spike') - self.check_post_attrs('input', 'V') - - # parameters - self.E = E - self.tau = tau - self.delay = delay - self.g_max = g_max - self.pre2post = self.conn.require('pre2post') - - # variables - self.g = bm.Variable(bm.zeros(self.post.num)) - - # function - self.integral = bp.odeint(lambda g, t: -g / self.tau, method=method) - - def update(self, tdi): - self.g.value = self.integral(self.g, tdi['t'], tdi['dt']) - # Customized operator - # ------------------------------------------------------------------------------------------------------------ - post_val = bm.zeros(self.post.num) - r = event_sum2(self.pre.spike, self.pre2post[0], self.pre2post[1], post_val, self.g_max) - self.g += r[0] - # ------------------------------------------------------------------------------------------------------------ - self.post.input += self.g * (self.E - self.post.V) - - -class EINet(bp.Network): - def __init__(self, syn_class, scale=1.0, method='exp_auto', ): - super(EINet, self).__init__() - - # network size - num_exc = int(3200 * scale) - num_inh = int(800 * scale) - - # neurons - pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) - self.E = bp.neurons.LIF(num_exc, **pars, method=method) - self.I = bp.neurons.LIF(num_inh, **pars, method=method) - self.E.V[:] = bm.random.randn(num_exc) * 2 - 55. - self.I.V[:] = bm.random.randn(num_inh) * 2 - 55. - - # synapses - we = 0.6 / scale # excitatory synaptic weight (voltage) - wi = 6.7 / scale # inhibitory synaptic weight - self.E2E = syn_class(self.E, self.E, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) - self.E2I = syn_class(self.E, self.I, bp.conn.FixedProb(0.02), E=0., g_max=we, tau=5., method=method) - self.I2E = syn_class(self.I, self.E, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) - self.I2I = syn_class(self.I, self.I, bp.conn.FixedProb(0.02), E=-80., g_max=wi, tau=10., method=method) - - -class TestOpRegister(unittest.TestCase): - def test_op(self): - bm.random.seed(123) - fig, gs = bp.visualize.get_figure(1, 2, 4, 5) - - net = EINet(ExponentialSyn, scale=0.1, method='euler') - runner = bp.DSRunner( - net, - inputs=[(net.E.input, 20.), (net.I.input, 20.)], - monitors={'E.spike': net.E.spike}, - ) - t, _ = runner.run(100., eval_time=True) - print(t) - ax = fig.add_subplot(gs[0, 0]) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], ax=ax) - - net3 = EINet(ExponentialSyn3, scale=0.1, method='euler') - runner3 = bp.DSRunner( - net3, - inputs=[(net3.E.input, 20.), (net3.I.input, 20.)], - monitors={'E.spike': net3.E.spike}, - ) - t, _ = runner3.run(100., eval_time=True) - print(t) - plt.close() - bm.clear_buffer_memory() diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 177b60aa6..f356f44b3 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -490,6 +490,12 @@ def update_STDP( ): raise NotImplementedError + def stdp_update_on_pre(self, pre_spike, trace, *args, **kwargs): + raise NotImplementedError + + def stdp_update_on_post(self, post_spike, trace, *args, **kwargs): + raise NotImplementedError + T = TypeVar('T') diff --git a/examples/dynamics_simulation/stdp.py b/examples/dynamics_simulation/stdp.py new file mode 100644 index 000000000..edaf90e44 --- /dev/null +++ b/examples/dynamics_simulation/stdp.py @@ -0,0 +1,64 @@ +""" +Reproduce the following STDP paper: + +- Song, S., Miller, K. & Abbott, L. Competitive Hebbian learning through spike-timing-dependent + synaptic plasticity. Nat Neurosci 3, 919–926 (2000). https://doi.org/10.1038/78829 +""" + +import matplotlib.pyplot as plt +import numpy as np + +import brainpy as bp +import brainpy.math as bm + + +class STDPNet(bp.DynSysGroup): + def __init__(self, num_poisson, num_lif=1, g_max=0.01): + super().__init__() + + self.g_max = g_max + + # neuron groups + self.noise = bp.dyn.PoissonGroup(num_poisson, freqs=15.) + self.group = bp.dyn.Lif(num_lif, V_reset=-60., V_rest=-74, V_th=-54, tau=10., + V_initializer=bp.init.Normal(-60., 1.)) + + # synapses + syn = bp.dyn.Expon.desc(num_lif, tau=5.) + out = bp.dyn.COBA.desc(E=0.) + comm = bp.dnn.AllToAll(num_poisson, num_lif, bp.init.Uniform(0., g_max)) + self.syn = bp.dyn.STDP_Song2000(self.noise, None, syn, comm, out, self.group, + tau_s=20, tau_t=20, W_max=g_max, W_min=0., + A1=0.01 * g_max, A2=0.0105 * g_max) + + def update(self, *args, **kwargs): + self.noise() + self.syn() + self.group() + return self.syn.comm.weight.flatten()[:10] + + +def run_model(): + net = STDPNet(1000, 1) + indices = np.arange(int(100.0e3 / bm.dt)) # 100 s + ws = bm.for_loop(net.step_run, indices, progress_bar=True) + weight = bm.as_numpy(net.syn.comm.weight.flatten()) + + fig, gs = bp.visualize.get_figure(3, 1, 3, 10) + fig.add_subplot(gs[0, 0]) + plt.plot(weight / net.g_max, '.k') + plt.xlabel('Weight / gmax') + + fig.add_subplot(gs[1, 0]) + plt.hist(weight / net.g_max, 20) + plt.xlabel('Weight / gmax') + + fig.add_subplot(gs[2, 0]) + plt.plot(indices * bm.dt, bm.as_numpy(ws) / net.g_max) + plt.xlabel('Time (s)') + plt.ylabel('Weight / gmax') + plt.show() + + +if __name__ == '__main__': + run_model()