Skip to content

Commit

Permalink
[model] add input variables.
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Sep 16, 2023
1 parent c549082 commit 65fe620
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 7 deletions.
1 change: 1 addition & 0 deletions brainpy/_src/dyn/projections/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .aligns import *
from .conn import *
from .others import *
from .inputs import *
96 changes: 96 additions & 0 deletions brainpy/_src/dyn/projections/inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from typing import Optional, Any

from brainpy import math as bm
from brainpy._src.dynsys import Dynamic
from brainpy._src.mixin import SupportAutoDelay
from brainpy.types import Shape

__all__ = [
'InputVar',
]


class InputVar(Dynamic, SupportAutoDelay):
"""Define an input variable.
Example::
import brainpy as bp
class Exponential(bp.Projection):
def __init__(self, pre, post, prob, g_max, tau, E=0.):
super().__init__()
self.proj = bp.dyn.ProjAlignPostMg2(
pre=pre,
delay=None,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(prob, pre=pre.num, post=post.num), g_max),
syn=bp.dyn.Expon.desc(post.num, tau=tau),
out=bp.dyn.COBA.desc(E=E),
post=post,
)
class EINet(bp.DynSysGroup):
def __init__(self, num_exc, num_inh, method='exp_auto'):
super(EINet, self).__init__()
# neurons
pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
V_initializer=bp.init.Normal(-55., 2.), method=method)
self.E = bp.dyn.LifRef(num_exc, **pars)
self.I = bp.dyn.LifRef(num_inh, **pars)
# synapses
w_e = 0.6 # excitatory synaptic weight
w_i = 6.7 # inhibitory synaptic weight
# Neurons connect to each other randomly with a connection probability of 2%
self.E2E = Exponential(self.E, self.E, 0.02, g_max=w_e, tau=5., E=0.)
self.E2I = Exponential(self.E, self.I, 0.02, g_max=w_e, tau=5., E=0.)
self.I2E = Exponential(self.I, self.E, 0.02, g_max=w_i, tau=10., E=-80.)
self.I2I = Exponential(self.I, self.I, 0.02, g_max=w_i, tau=10., E=-80.)
# define input variables given to E/I populations
self.Ein = bp.dyn.InputVar(self.E.varshape)
self.Iin = bp.dyn.InputVar(self.I.varshape)
self.E.add_inp_fun('', self.Ein)
self.I.add_inp_fun('', self.Iin)
net = EINet(3200, 800, method='exp_auto') # "method": the numerical integrator method
runner = bp.DSRunner(net, monitors=['E.spike', 'I.spike'], inputs=[('Ein.input', 20.), ('Iin.input', 20.)])
runner.run(100.)
# visualization
bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'],
title='Spikes of Excitatory Neurons', show=True)
bp.visualize.raster_plot(runner.mon.ts, runner.mon['I.spike'],
title='Spikes of Inhibitory Neurons', show=True)
"""
def __init__(
self,
size: Shape,
keep_size: bool = False,
sharding: Optional[Any] = None,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
method: str = 'exp_auto'
):
super().__init__(size=size, keep_size=keep_size, sharding=sharding, name=name, mode=mode, method=method)

self.reset_state(self.mode)

def reset_state(self, batch_or_mode=None):
self.input = self.init_variable(bm.zeros, batch_or_mode)

def update(self, *args, **kwargs):
return self.input.value

def return_info(self):
return self.input

def clear_input(self, *args, **kwargs):
self.reset_state(self.mode)
6 changes: 1 addition & 5 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
name : optional, str
The name of the dynamical system.
mode: optional, Mode
The model computation mode. It should be instance of :py:class:`~.Mode`.
The model computation mode. It should be an instance of :py:class:`~.Mode`.
"""

supported_modes: Optional[Sequence[bm.Mode]] = None
Expand Down Expand Up @@ -610,10 +610,6 @@ def reset_state(self, *args, **kwargs):
else:
raise ValueError('Do not implement the reset_state() function.')

def clear_input(self, *args, **kwargs):
"""Empty function of clearing inputs."""
pass


class Dynamic(DynamicalSystem):
"""Base class to model dynamics.
Expand Down
4 changes: 3 additions & 1 deletion brainpy/dyn/projections.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,6 @@
PoissonInput as PoissonInput,
)


from brainpy._src.dyn.projections.inputs import (
InputVar,
)
29 changes: 28 additions & 1 deletion docs/apis/brainpy.dyn.projections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ Synaptic Projections



Reduced Projections
-------------------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

VanillaProj
ProjAlignPostMg1
ProjAlignPostMg2
ProjAlignPost1
Expand All @@ -20,5 +22,30 @@ Synaptic Projections
ProjAlignPreMg2
ProjAlignPre1
ProjAlignPre2



Projections
-----------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst

VanillaProj
SynConn



Inputs
------

.. autosummary::
:toctree: generated/
:nosignatures:
:template: classtemplate.rst


PoissonInput
InputVar

0 comments on commit 65fe620

Please sign in to comment.