Skip to content

Commit

Permalink
Merge pull request #491 from chaoming0625/master
Browse files Browse the repository at this point in the history
update STDP
  • Loading branch information
chaoming0625 authored Sep 11, 2023
2 parents 4585f20 + 46bd161 commit 4a64366
Show file tree
Hide file tree
Showing 7 changed files with 150 additions and 128 deletions.
63 changes: 13 additions & 50 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,18 @@
from typing import Optional, Callable, Union

from brainpy.types import ArrayType
from brainpy import math as bm, check
from brainpy._src.delay import DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo,
AutoDelaySupp, BindCondData, AlignPost, SupportSTDP)
from brainpy._src.initialize import parameter
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.mixin import (JointType, ParamDescInit, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP)
from brainpy.types import ArrayType
from .aligns import _AlignPost, _AlignPreMg, _get_return

__all__ = [
'STDP_Song2000'
'STDP_Song2000',
]

class _AlignPre(DynamicalSystem):
def __init__(self, syn, delay=None):
super().__init__()
self.syn = syn
self.delay = delay

def update(self, x):
if self.delay is None:
return x >> self.syn
else:
return x >> self.syn >> self.delay


class _AlignPost(DynamicalSystem):
def __init__(self,
syn: Callable,
out: JointType[DynamicalSystem, BindCondData]):
super().__init__()
self.syn = syn
self.out = out

def update(self, *args, **kwargs):
self.out.bind_cond(self.syn(*args, **kwargs))


class _AlignPreMg(DynamicalSystem):
def __init__(self, access, syn):
super().__init__()
self.access = access
self.syn = syn

def update(self, *args, **kwargs):
return self.syn(self.access())


def _get_return(return_info):
if isinstance(return_info, bm.Variable):
return return_info.value
elif isinstance(return_info, ReturnInfo):
return return_info.get_data()
else:
raise NotImplementedError


class STDP_Song2000(Projection):
r"""Synaptic output with spike-time-dependent plasticity.
Expand Down Expand Up @@ -135,9 +92,10 @@ class STDP_Song2000(Projection):
A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike.
%s
"""

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: ParamDescInit[DynamicalSystem],
comm: DynamicalSystem,
Expand All @@ -148,14 +106,15 @@ def __init__(
tau_t: Union[float, ArrayType, Callable] = 33.7,
A1: Union[float, ArrayType, Callable] = 0.96,
A2: Union[float, ArrayType, Callable] = 0.53,
# others
out_label: Optional[str] = None,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, ParamDescInit[DynamicalSystem])
check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
Expand Down Expand Up @@ -252,20 +211,24 @@ def calculate_trace(
return delay_cls.get_bef_update(_syn_id).syn

def update(self):
# pre spikes, and pre-synaptic variables
if issubclass(self.syn.cls, AlignPost):
pre_spike = self.refs['delay'].at(self.name)
x = pre_spike
else:
pre_spike = self.refs['delay'].access()
x = _get_return(self.refs['syn'].return_info())

# post spikes
post_spike = self.refs['post'].spike

# weight updates
Apre = self.refs['pre_trace'].g
Apost = self.refs['post_trace'].g
delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike)
self.comm.update_STDP(delta_w)

# currents
current = self.comm(x)
if issubclass(self.syn.cls, AlignPost):
self.refs['syn'].add_current(current) # synapse post current
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
from brainpy._src.deprecations import _update_deprecate_msg
Expand Down Expand Up @@ -70,7 +70,7 @@ def update(self, x):
return func


class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj):
class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
"""Base Dynamical System class.
.. note::
Expand Down
2 changes: 2 additions & 0 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def fun(self):
# that has been created.
a = self.tracing_variable('a', bm.zeros, (10,))
.. versionadded:: 2.4.5
Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
Expand Down
13 changes: 12 additions & 1 deletion brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = [
'RandomState', 'Generator', 'DEFAULT',

'seed', 'default_rng', 'split_key',
'seed', 'default_rng', 'split_key', 'split_keys',

# numpy compatibility
'rand', 'randint', 'random_integers', 'randn', 'random',
Expand Down Expand Up @@ -1258,6 +1258,8 @@ def split_keys(n):
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.
.. versionadded:: 2.4.5
Parameters
----------
n : int
Expand All @@ -1267,6 +1269,15 @@ def split_keys(n):


def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
"""Clone the random state according to the given setting.
Args:
seed_or_key: The seed (an integer) or the random key.
clone: Bool. Whether clone the default random state.
Returns:
The random state.
"""
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
Expand Down
138 changes: 74 additions & 64 deletions brainpy/_src/mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numbers
import sys
import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
Expand Down Expand Up @@ -28,12 +27,15 @@
'ParamDesc',
'ParamDescInit',
'AlignPost',
'SupportAutoDelay',
'Container',
'TreeNode',
'BindCondData',
'JointType',
'SupportSTDP',
'SupportAutoDelay',
'SupportInputProj',
'SupportOnline',
'SupportOffline',
]

global_delay_data = dict()
Expand All @@ -47,59 +49,6 @@ class MixIn(object):
pass


class ReceiveInputProj(MixIn):
"""The :py:class:`~.MixIn` that receives the input projections.
Note that the subclass should define a ``cur_inputs`` attribute.
"""
cur_inputs: bm.node_dict

def add_inp_fun(self, key: Any, fun: Callable):
"""Add an input function.
Args:
key: The dict key.
fun: The function to generate inputs.
"""
if not callable(fun):
raise TypeError('Must be a function.')
if key in self.cur_inputs:
raise ValueError(f'Key "{key}" has been defined and used.')
self.cur_inputs[key] = fun

def get_inp_fun(self, key):
"""Get the input function.
Args:
key: The key.
Returns:
The input function which generates currents.
"""
return self.cur_inputs.get(key)

def sum_inputs(self, *args, init=0., label=None, **kwargs):
"""Summarize all inputs by the defined input functions ``.cur_inputs``.
Args:
*args: The arguments for input functions.
init: The initial input data.
**kwargs: The arguments for input functions.
Returns:
The total currents.
"""
if label is None:
for key, out in self.cur_inputs.items():
init = init + out(*args, **kwargs)
else:
for key, out in self.cur_inputs.items():
if key.startswith(label + ' // '):
init = init + out(*args, **kwargs)
return init


class ParamDesc(MixIn):
""":py:class:`~.MixIn` indicates the function for describing initialization parameters.
Expand Down Expand Up @@ -208,13 +157,6 @@ def get_data(self):
return init


class SupportAutoDelay(MixIn):
"""``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

def return_info(self) -> Union[bm.Variable, ReturnInfo]:
raise NotImplementedError('Must implement the "return_info()" function.')


class Container(MixIn):
"""Container :py:class:`~.MixIn` which wrap a group of objects.
"""
Expand Down Expand Up @@ -550,8 +492,71 @@ def get_delay_var(self, name):
return global_delay_data[name]


class SupportInputProj(MixIn):
"""The :py:class:`~.MixIn` that receives the input projections.
Note that the subclass should define a ``cur_inputs`` attribute.
"""
cur_inputs: bm.node_dict

def add_inp_fun(self, key: Any, fun: Callable):
"""Add an input function.
Args:
key: The dict key.
fun: The function to generate inputs.
"""
if not callable(fun):
raise TypeError('Must be a function.')
if key in self.cur_inputs:
raise ValueError(f'Key "{key}" has been defined and used.')
self.cur_inputs[key] = fun

def get_inp_fun(self, key):
"""Get the input function.
Args:
key: The key.
Returns:
The input function which generates currents.
"""
return self.cur_inputs.get(key)

def sum_inputs(self, *args, init=0., label=None, **kwargs):
"""Summarize all inputs by the defined input functions ``.cur_inputs``.
Args:
*args: The arguments for input functions.
init: The initial input data.
**kwargs: The arguments for input functions.
Returns:
The total currents.
"""
if label is None:
for key, out in self.cur_inputs.items():
init = init + out(*args, **kwargs)
else:
for key, out in self.cur_inputs.items():
if key.startswith(label + ' // '):
init = init + out(*args, **kwargs)
return init


class SupportAutoDelay(MixIn):
"""``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

def return_info(self) -> Union[bm.Variable, ReturnInfo]:
raise NotImplementedError('Must implement the "return_info()" function.')


class SupportOnline(MixIn):
""":py:class:`~.MixIn` to support the online training methods."""
""":py:class:`~.MixIn` to support the online training methods.
.. versionadded:: 2.4.5
"""

online_fit_by: Optional # methods for online fitting

Expand All @@ -563,7 +568,10 @@ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):


class SupportOffline(MixIn):
""":py:class:`~.MixIn` to support the offline training methods."""
""":py:class:`~.MixIn` to support the offline training methods.
.. versionadded:: 2.4.5
"""

offline_fit_by: Optional # methods for offline fitting

Expand All @@ -573,6 +581,8 @@ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):

class BindCondData(MixIn):
"""Bind temporary conductance data.
"""
_conductance: Optional

Expand Down
4 changes: 2 additions & 2 deletions brainpy/mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

from brainpy._src.mixin import (
MixIn as MixIn,
ReceiveInputProj as ReceiveInputProj,
SupportInputProj as SupportInputProj,
AlignPost as AlignPost,
AutoDelaySupp as AutoDelaySupp,
SupportAutoDelay as SupportAutoDelay,
ParamDesc as ParamDesc,
ParamDescInit as ParamDescInit,
BindCondData as BindCondData,
Expand Down
Loading

0 comments on commit 4a64366

Please sign in to comment.