Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Oct 28, 2023
1 parent ae3c966 commit 6e57e2b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 34 deletions.
5 changes: 3 additions & 2 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,12 @@ def __init__(
super().__init__(mode=delay.mode)
self.refs = {'delay': delay}
assert isinstance(delay, Delay)
delay.register_entry(delay_entry or self.name, time)
self._delay_entry = delay_entry or self.name
delay.register_entry(self._delay_entry, time)
self.indices = indices

def update(self):
return self.refs['delay'].at(self.name, *self.indices)
return self.refs['delay'].at(self._delay_entry, *self.indices)

def reset_state(self, *args, **kwargs):
pass
Expand Down
17 changes: 12 additions & 5 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ def run(i, I_pre, I_post):
pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])
Args:
tau_s: float, ArrayType, Callable. The time constant of :math:`A_{pre}`.
tau_t: float, ArrayType, Callable. The time constant of :math:`A_{post}`.
A1: float, ArrayType, Callable. The increment of :math:`A_{pre}` produced by a spike.
A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike.
tau_s: float. The time constant of :math:`A_{pre}`.
tau_t: float. The time constant of :math:`A_{post}`.
A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
W_max: float. The maximum weight.
pre: DynamicalSystem. The pre-synaptic neuron group.
delay: int, float. The pre spike delay length. (ms)
syn: DynamicalSystem. The synapse model.
Expand All @@ -133,6 +134,7 @@ def __init__(
tau_t: Union[float, ArrayType, Callable] = 33.7,
A1: Union[float, ArrayType, Callable] = 0.96,
A2: Union[float, ArrayType, Callable] = 0.53,
W_max: Optional[float] = None,
# others
out_label: Optional[str] = None,
name: Optional[str] = None,
Expand Down Expand Up @@ -176,6 +178,7 @@ def __init__(
self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))

# synapse parameters
self.W_max = W_max
self.tau_s = parameter(tau_s, sizes=self.pre_num)
self.tau_t = parameter(tau_t, sizes=self.post_num)
self.A1 = parameter(A1, sizes=self.pre_num)
Expand All @@ -201,7 +204,7 @@ def update(self):
Apre = self.refs['pre_trace'].g
Apost = self.refs['post_trace'].g
delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike)
self.comm.update_STDP(delta_w)
self.comm.update_STDP(delta_w, constraints=self._weight_clip)

# currents
current = self.comm(x)
Expand All @@ -210,3 +213,7 @@ def update(self):
else:
self.refs['out'].bind_cond(current) # align pre
return current

def _weight_clip(self, w):
return w if self.W_max is None else bm.minimum(w, self.W_max)

83 changes: 56 additions & 27 deletions brainpy/_src/dyn/projections/tests/test_aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

def test_ProjAlignPreMg1():
class EICOBA_PreAlign(bp.DynamicalSystem):
def __init__(self, scale=1., inp=20.):
def __init__(self, scale=1., inp=20., delay=None):
super().__init__()

self.inp = inp
Expand All @@ -22,31 +22,31 @@ def __init__(self, scale=1., inp=20.):
self.E2I = bp.dyn.ProjAlignPreMg1(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=None,
delay=delay,
comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6),
out=bp.dyn.COBA(E=0.),
post=self.I,
)
self.E2E = bp.dyn.ProjAlignPreMg1(
pre=self.E,
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
delay=None,
delay=delay,
comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6),
out=bp.dyn.COBA(E=0.),
post=self.E,
)
self.I2E = bp.dyn.ProjAlignPreMg1(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=None,
delay=delay,
comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7),
out=bp.dyn.COBA(E=-80.),
post=self.E,
)
self.I2I = bp.dyn.ProjAlignPreMg1(
pre=self.I,
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
delay=None,
delay=delay,
comm=bp.dnn.CSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7),
out=bp.dyn.COBA(E=-80.),
post=self.I,
Expand All @@ -65,13 +65,19 @@ def update(self):
indices = np.arange(400)
spks = bm.for_loop(net.step_run, indices)
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)

net = EICOBA_PreAlign(0.5, delay=1.)
indices = np.arange(400)
spks = bm.for_loop(net.step_run, indices)
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)

plt.close()
bm.clear_buffer_memory()


def test_ProjAlignPostMg2():
class EICOBA_PostAlign(bp.DynamicalSystem):
def __init__(self, scale, inp=20., ltc=True):
def __init__(self, scale, inp=20., ltc=True, delay=None):
super().__init__()
self.inp = inp

Expand All @@ -86,31 +92,31 @@ def __init__(self, scale, inp=20., ltc=True):

self.E2E = bp.dyn.ProjAlignPostMg2(
pre=self.E,
delay=None,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.E.num), 0.6),
syn=bp.dyn.Expon.desc(self.E.varshape, tau=5.),
out=bp.dyn.COBA.desc(E=0.),
post=self.E,
)
self.E2I = bp.dyn.ProjAlignPostMg2(
pre=self.E,
delay=None,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.E.num, post=self.I.num), 0.6),
syn=bp.dyn.Expon.desc(self.I.varshape, tau=5.),
out=bp.dyn.COBA.desc(E=0.),
post=self.I,
)
self.I2E = bp.dyn.ProjAlignPostMg2(
pre=self.I,
delay=None,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.E.num), 6.7),
syn=bp.dyn.Expon.desc(self.E.varshape, tau=10.),
out=bp.dyn.COBA.desc(E=-80.),
post=self.E,
)
self.I2I = bp.dyn.ProjAlignPostMg2(
pre=self.I,
delay=None,
delay=delay,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=self.I.num, post=self.I.num), 6.7),
syn=bp.dyn.Expon.desc(self.I.varshape, tau=10.),
out=bp.dyn.COBA.desc(E=-80.),
Expand All @@ -131,6 +137,11 @@ def update(self):
spks = bm.for_loop(net.step_run, indices)
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)

net = EICOBA_PostAlign(0.5, delay=1.)
indices = np.arange(400)
spks = bm.for_loop(net.step_run, indices)
bp.visualize.raster_plot(indices * bm.dt, spks, show=True)

net = EICOBA_PostAlign(0.5, ltc=False)
indices = np.arange(400)
spks = bm.for_loop(net.step_run, indices)
Expand Down Expand Up @@ -178,7 +189,7 @@ def update(self, input):

def test_ProjAlignPost2():
class EINet(bp.DynSysGroup):
def __init__(self, scale):
def __init__(self, scale, delay=None):
super().__init__()
ne, ni = int(3200 * scale), int(800 * scale)
p = 80 / (ne + ni)
Expand All @@ -188,25 +199,25 @@ def __init__(self, scale):
self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
self.E2E = bp.dyn.ProjAlignPost2(pre=self.E,
delay=0.1,
delay=delay,
comm=bp.dnn.EventJitFPHomoLinear(ne, ne, prob=p, weight=0.6),
syn=bp.dyn.Expon(size=ne, tau=5.),
out=bp.dyn.COBA(E=0.),
post=self.E)
self.E2I = bp.dyn.ProjAlignPost2(pre=self.E,
delay=0.1,
delay=delay,
comm=bp.dnn.EventJitFPHomoLinear(ne, ni, prob=p, weight=0.6),
syn=bp.dyn.Expon(size=ni, tau=5.),
out=bp.dyn.COBA(E=0.),
post=self.I)
self.I2E = bp.dyn.ProjAlignPost2(pre=self.I,
delay=0.1,
delay=delay,
comm=bp.dnn.EventJitFPHomoLinear(ni, ne, prob=p, weight=6.7),
syn=bp.dyn.Expon(size=ne, tau=10.),
out=bp.dyn.COBA(E=-80.),
post=self.E)
self.I2I = bp.dyn.ProjAlignPost2(pre=self.I,
delay=0.1,
delay=delay,
comm=bp.dnn.EventJitFPHomoLinear(ni, ni, prob=p, weight=6.7),
syn=bp.dyn.Expon(size=ni, tau=10.),
out=bp.dyn.COBA(E=-80.),
Expand All @@ -221,10 +232,16 @@ def update(self, inp):
self.I(inp)
return self.E.spike

model = EINet(0.5)
model = EINet(0.5, delay=1.)
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

model = EINet(0.5, delay=None)
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

bm.clear_buffer_memory()
plt.close()

Expand Down Expand Up @@ -267,7 +284,7 @@ def update(self, input):

def test_ProjAlignPreMg1_v2():
class EINet(bp.DynSysGroup):
def __init__(self, scale=1.):
def __init__(self, scale=1., delay=None):
super().__init__()
ne, ni = int(3200 * scale), int(800 * scale)
p = 80 / (4000 * scale)
Expand All @@ -277,25 +294,25 @@ def __init__(self, scale=1.):
V_initializer=bp.init.Normal(-55., 2.))
self.E2E = bp.dyn.ProjAlignPreMg1(pre=self.E,
syn=bp.dyn.Expon.desc(size=ne, tau=5.),
delay=0.1,
delay=delay,
comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
out=bp.dyn.COBA(E=0.),
post=self.E)
self.E2I = bp.dyn.ProjAlignPreMg1(pre=self.E,
syn=bp.dyn.Expon.desc(size=ne, tau=5.),
delay=0.1,
delay=delay,
comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
out=bp.dyn.COBA(E=0.),
post=self.I)
self.I2E = bp.dyn.ProjAlignPreMg1(pre=self.I,
syn=bp.dyn.Expon.desc(size=ni, tau=10.),
delay=0.1,
delay=delay,
comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
out=bp.dyn.COBA(E=-80.),
post=self.E)
self.I2I = bp.dyn.ProjAlignPreMg1(pre=self.I,
syn=bp.dyn.Expon.desc(size=ni, tau=10.),
delay=0.1,
delay=delay,
comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
out=bp.dyn.COBA(E=-80.),
post=self.I)
Expand All @@ -313,13 +330,19 @@ def update(self, inp):
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

model = EINet(delay=1.)
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

bm.clear_buffer_memory()
plt.close()


def test_ProjAlignPreMg2():
class EINet(bp.DynSysGroup):
def __init__(self, scale=1.):
def __init__(self, scale=1., delay=None):
super().__init__()
ne, ni = int(3200 * scale), int(800 * scale)
p = 80 / (4000 * scale)
Expand All @@ -328,25 +351,25 @@ def __init__(self, scale=1.):
self.I = bp.dyn.LifRefLTC(ni, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
self.E2E = bp.dyn.ProjAlignPreMg2(pre=self.E,
delay=0.1,
delay=delay,
syn=bp.dyn.Expon.desc(size=ne, tau=5.),
comm=bp.dnn.JitFPHomoLinear(ne, ne, prob=p, weight=0.6),
out=bp.dyn.COBA(E=0.),
post=self.E)
self.E2I = bp.dyn.ProjAlignPreMg2(pre=self.E,
delay=0.1,
delay=delay,
syn=bp.dyn.Expon.desc(size=ne, tau=5.),
comm=bp.dnn.JitFPHomoLinear(ne, ni, prob=p, weight=0.6),
out=bp.dyn.COBA(E=0.),
post=self.I)
self.I2E = bp.dyn.ProjAlignPreMg2(pre=self.I,
delay=0.1,
delay=delay,
syn=bp.dyn.Expon.desc(size=ni, tau=10.),
comm=bp.dnn.JitFPHomoLinear(ni, ne, prob=p, weight=6.7),
out=bp.dyn.COBA(E=-80.),
post=self.E)
self.I2I = bp.dyn.ProjAlignPreMg2(pre=self.I,
delay=0.1,
delay=delay,
syn=bp.dyn.Expon.desc(size=ni, tau=10.),
comm=bp.dnn.JitFPHomoLinear(ni, ni, prob=p, weight=6.7),
out=bp.dyn.COBA(E=-80.),
Expand All @@ -361,10 +384,16 @@ def update(self, inp):
self.I(inp)
return self.E.spike

model = EINet()
model = EINet(scale=0.2, delay=None)
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

model = EINet(scale=0.2, delay=1.)
indices = bm.arange(400)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)

bm.clear_buffer_memory()
plt.close()

Expand Down

0 comments on commit 6e57e2b

Please sign in to comment.