Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 29, 2023
2 parents b28cb6a + c346f29 commit 5280227
Show file tree
Hide file tree
Showing 10 changed files with 239 additions and 61 deletions.
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/neurons/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class CondNeuGroupLTC(HHTypedNeuron, Container, TreeNode):
where :math:`\alpha_{x}` and :math:`\beta_{x}` are rate constants.
.. versionadded:: 2.1.9
Model the conductance-based neuron model.
Modeling the conductance-based neuron model.
Parameters
----------
Expand Down
27 changes: 14 additions & 13 deletions brainpy/_src/dyn/neurons/lif.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

import brainpy.math as bm
from brainpy._src.context import share
from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc
from brainpy._src.dyn.neurons.base import GradNeuDyn
from brainpy._src.initialize import ZeroInit, OneInit
from brainpy._src.integrators import odeint, JointEq
from brainpy.check import is_initializer
from brainpy.types import Shape, ArrayType, Sharding
from brainpy._src.dyn._docs import ref_doc, lif_doc, pneu_doc, dpneu_doc, ltc_doc, if_doc
from brainpy._src.dyn.neurons.base import GradNeuDyn

__all__ = [
'IF',
Expand Down Expand Up @@ -994,6 +994,7 @@ class ExpIFRefLTC(ExpIFLTC):
%s
"""

def __init__(
self,
size: Shape,
Expand Down Expand Up @@ -1221,6 +1222,7 @@ class ExpIFRef(ExpIFRefLTC):
%s
%s
"""

def derivative(self, V, t, I):
exp_v = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
dvdt = (- (V - self.V_rest) + exp_v + self.R * I) / self.tau
Expand Down Expand Up @@ -1424,7 +1426,8 @@ def update(self, x=None):
x = 0. if x is None else x

# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs()
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
V += self.sum_delta_inputs()

# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
Expand Down Expand Up @@ -1756,7 +1759,8 @@ def update(self, x=None):
x = 0. if x is None else x

# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt) + self.sum_delta_inputs()
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
V += self.sum_delta_inputs()

# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
Expand Down Expand Up @@ -2444,7 +2448,6 @@ class QuaIFRef(QuaIFRefLTC):
%s
"""


def derivative(self, V, t, I):
dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I) / self.tau
return dVdt
Expand Down Expand Up @@ -2633,7 +2636,7 @@ def update(self, x=None):

# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
V = V + self.sum_delta_inputs()
V += self.sum_delta_inputs()

# spike, spiking time, and membrane potential reset
if isinstance(self.mode, bm.TrainingMode):
Expand Down Expand Up @@ -2940,7 +2943,7 @@ def update(self, x=None):

# integrate membrane potential
V, w = self.integral(self.V.value, self.w.value, t, x, dt)
V += self.sum_delta_inputs()
V += self.sum_delta_inputs()

# refractory
refractory = (t - self.t_last_spike) <= self.tau_ref
Expand Down Expand Up @@ -3576,7 +3579,6 @@ class GifRefLTC(GifLTC):
%s
"""


def __init__(
self,
size: Shape,
Expand Down Expand Up @@ -3844,7 +3846,6 @@ class GifRef(GifRefLTC):
%s
"""


def dV(self, V, t, I1, I2, I):
return (- (V - self.V_rest) + self.R * (I + I1 + I2)) / self.tau

Expand Down Expand Up @@ -4495,7 +4496,7 @@ def update(self, x=None):
return super().update(x)


Izhikevich.__doc__ = Izhikevich.__doc__ %(pneu_doc, dpneu_doc)
IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ %(pneu_doc, dpneu_doc, ref_doc)
IzhikevichRef.__doc__ = IzhikevichRef.__doc__ %(pneu_doc, dpneu_doc, ref_doc)
IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ %()
Izhikevich.__doc__ = Izhikevich.__doc__ % (pneu_doc, dpneu_doc)
IzhikevichRefLTC.__doc__ = IzhikevichRefLTC.__doc__ % (pneu_doc, dpneu_doc, ref_doc)
IzhikevichRef.__doc__ = IzhikevichRef.__doc__ % (pneu_doc, dpneu_doc, ref_doc)
IzhikevichLTC.__doc__ = IzhikevichLTC.__doc__ % ()
56 changes: 52 additions & 4 deletions brainpy/_src/dyn/projections/align_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,19 @@ def reset_state(self, *args, **kwargs):


class HalfProjAlignPostMg(Projection):
r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
r"""Defining the half part of synaptic projection with the align-post reduction and the automatic synapse merging.
The ``half-part`` means that the model only needs to provide half information needed for a projection,
including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs
the manual providing of the spiking input.
The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
parameters (such like time constants) will also share the same synaptic variables.
All align-post projection models prefer to use the event-driven computation mode. This means that the
``comm`` model should be the event-driven model.
**Code Examples**
Expand Down Expand Up @@ -131,7 +143,22 @@ def update(self, x):


class FullProjAlignPostMg(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
"""Full-chain synaptic projection with the align-post reduction and the automatic synapse merging.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``.
The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
parameters (such like time constants) will also share the same synaptic variables.
All align-post projection models prefer to use the event-driven computation mode. This means that the
``comm`` model should be the event-driven model.
Moreover, it's worth noting that ``FullProjAlignPostMg`` has a different updating order with all align-pre
projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``.
While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``.
**Code Examples**
Expand Down Expand Up @@ -245,7 +272,16 @@ def update(self):


class HalfProjAlignPost(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
"""Defining the half-part of synaptic projection with the align-post reduction.
The ``half-part`` means that the model only needs to provide half information needed for a projection,
including ``comm`` -> ``syn`` -> ``out`` -> ``post``. Therefore, the model's ``update`` function needs
the manual providing of the spiking input.
The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
All align-post projection models prefer to use the event-driven computation mode. This means that the
``comm`` model should be the event-driven model.
To simulate an E/I balanced network:
Expand Down Expand Up @@ -329,7 +365,19 @@ def update(self, x):


class FullProjAlignPost(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
"""Full-chain synaptic projection with the align-post reduction.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``delay`` -> ``comm`` -> ``syn`` -> ``out`` -> ``post``.
The ``align-post`` means that the synaptic variables have the same dimension as the post-synaptic neuron group.
All align-post projection models prefer to use the event-driven computation mode. This means that the
``comm`` model should be the event-driven model.
Moreover, it's worth noting that ``FullProjAlignPost`` has a different updating order with all align-pre
projection models. The updating order of align-post projections is ``spikes`` -> ``comm`` -> ``syn`` -> ``out``.
While, the updating order of all align-pre projection models is usually ``spikes`` -> ``syn`` -> ``comm`` -> ``out``.
To simulate and define an E/I balanced network model:
Expand Down
69 changes: 64 additions & 5 deletions brainpy/_src/dyn/projections/align_pre.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from brainpy._src.delay import (Delay, DelayAccess, init_delay_by_return, register_delay_by_return)
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay, BindCondData)
from .base import _get_return
from .utils import _get_return

__all__ = [
'FullProjAlignPreSDMg', 'FullProjAlignPreDSMg',
Expand Down Expand Up @@ -68,7 +68,22 @@ def reset_state(self, *args, **kwargs):


class FullProjAlignPreSDMg(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
"""Full-chain synaptic projection with the align-pre reduction and synapse+delay updating and merging.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
synapse states to the delay model, and finally computes the synaptic current.
The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
parameters (such like time constants) will also share the same synaptic variables.
Neither ``FullProjAlignPreSDMg`` nor ``FullProjAlignPreDSMg``facilitates the event-driven computation.
This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
than the spiking. To facilitate the event-driven computation, please use align post projections.
To simulate an E/I balanced network model:
Expand Down Expand Up @@ -182,7 +197,24 @@ def update(self, x=None):


class FullProjAlignPreDSMg(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating and merging.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``delay`` -> ``syn`` -> ``comm`` -> ``out`` -> ``post``.
Note here, compared to ``FullProjAlignPreSDMg``, the ``delay`` and ``syn`` are exchanged.
The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.
The ``merging`` means that the same delay model is shared by all synapses, and the synapse model with same
parameters (such like time constants) will also share the same synaptic variables.
Neither ``FullProjAlignPreDSMg`` nor ``FullProjAlignPreSDMg`` facilitates the event-driven computation.
This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
than the spiking. To facilitate the event-driven computation, please use align post projections.
To simulate an E/I balanced network model:
Expand Down Expand Up @@ -296,7 +328,20 @@ def update(self):


class FullProjAlignPreSD(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
"""Full-chain synaptic projection with the align-pre reduction and synapse+delay updating.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
The ``synapse+delay updating`` means that the projection first computes the synapse states, then delivers the
synapse states to the delay model, and finally computes the synaptic current.
Neither ``FullProjAlignPreSD`` nor ``FullProjAlignPreDS``facilitates the event-driven computation.
This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
than the spiking. To facilitate the event-driven computation, please use align post projections.
To simulate an E/I balanced network model:
Expand Down Expand Up @@ -411,7 +456,21 @@ def update(self, x=None):


class FullProjAlignPreDS(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
"""Full-chain synaptic projection with the align-pre reduction and delay+synapse updating.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``syn`` -> ``delay`` -> ``comm`` -> ``out`` -> ``post``.
Note here, compared to ``FullProjAlignPreSD``, the ``delay`` and ``syn`` are exchanged.
The ``align-pre`` means that the synaptic variables have the same dimension as the pre-synaptic neuron group.
The ``delay+synapse updating`` means that the projection first delivers the pre neuron output (usually the
spiking) to the delay model, then computes the synapse states, and finally computes the synaptic current.
Neither ``FullProjAlignPreDS`` nor ``FullProjAlignPreSD`` facilitates the event-driven computation.
This is because the ``comm`` is computed after the synapse state, which is a floating-point number, rather
than the spiking. To facilitate the event-driven computation, please use align post projections.
To simulate an E/I balanced network model:
Expand Down
63 changes: 35 additions & 28 deletions brainpy/_src/dyn/projections/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,13 @@ def __call__(self, *args, **kwargs):


class HalfProjDelta(Projection):
"""Delta synaptic projection.
"""Defining the half-part of the synaptic projection for the Delta synapse model.
The synaptic projection requires the input is the spiking data, otherwise
the synapse is not the Delta synapse model.
The ``half-part`` means that the model only includes ``comm`` -> ``syn`` -> ``out`` -> ``post``.
Therefore, the model's ``update`` function needs the manual providing of the spiking input.
**Model Descriptions**
Expand Down Expand Up @@ -103,7 +109,13 @@ def update(self, x):


class FullProjDelta(Projection):
"""Delta synaptic projection.
"""Full-chain of the synaptic projection for the Delta synapse model.
The synaptic projection requires the input is the spiking data, otherwise
the synapse is not the Delta synapse model.
The ``full-chain`` means that the model needs to provide all information needed for a projection,
including ``pre`` -> ``delay`` -> ``comm`` -> ``post``.
**Model Descriptions**
Expand All @@ -121,36 +133,31 @@ class FullProjDelta(Projection):
**Code Examples**
To simulate an E/I balanced network model:
.. code-block::
class EINet(bp.DynSysGroup):
import brainpy as bp
import brainpy.math as bm
class Net(bp.DynamicalSystem):
def __init__(self):
super().__init__()
self.N = bp.dyn.LifRef(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.))
self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
self.syn1 = bp.dyn.Expon(size=3200, tau=5.)
self.syn2 = bp.dyn.Expon(size=800, tau=10.)
self.E = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
out=bp.dyn.COBA(E=0.),
post=self.N)
self.I = bp.dyn.VanillaProj(comm=bp.dnn.JitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
out=bp.dyn.COBA(E=-80.),
post=self.N)
def update(self, input):
spk = self.delay.at('I')
self.E(self.syn1(spk[:3200]))
self.I(self.syn2(spk[3200:]))
self.delay(self.N(input))
return self.N.spike.value
model = EINet()
indices = bm.arange(1000)
spks = bm.for_loop(lambda i: model.step_run(i, 20.), indices)
bp.visualize.raster_plot(indices, spks, show=True)
self.pre = bp.dyn.PoissonGroup(10, 100.)
self.post = bp.dyn.LifRef(1)
self.syn = bp.dyn.FullProjDelta(self.pre, 0., bp.dnn.Linear(10, 1, bp.init.OneInit(2.)), self.post)
def update(self):
self.syn()
self.pre()
self.post()
return self.post.V.value
net = Net()
indices = bm.arange(1000).to_numpy()
vs = bm.for_loop(net.step_run, indices, progress_bar=True)
bp.visualize.line_plot(indices, vs, show=True)
Args:
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from brainpy.types import ArrayType
from .align_post import (align_post_add_bef_update, )
from .align_pre import (align_pre2_add_bef_update, )
from .base import (_get_return, )
from .utils import (_get_return, )

__all__ = [
'STDP_Song2000',
Expand Down
Loading

0 comments on commit 5280227

Please sign in to comment.