Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update STDP #491

Merged
merged 3 commits into from
Sep 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 13 additions & 50 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,18 @@
from typing import Optional, Callable, Union

from brainpy.types import ArrayType
from brainpy import math as bm, check
from brainpy._src.delay import DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo,
AutoDelaySupp, BindCondData, AlignPost, SupportSTDP)
from brainpy._src.initialize import parameter
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.mixin import (JointType, ParamDescInit, SupportAutoDelay, BindCondData, AlignPost, SupportSTDP)
from brainpy.types import ArrayType
from .aligns import _AlignPost, _AlignPreMg, _get_return

__all__ = [
'STDP_Song2000'
'STDP_Song2000',
]

class _AlignPre(DynamicalSystem):
def __init__(self, syn, delay=None):
super().__init__()
self.syn = syn
self.delay = delay

def update(self, x):
if self.delay is None:
return x >> self.syn
else:
return x >> self.syn >> self.delay


class _AlignPost(DynamicalSystem):
def __init__(self,
syn: Callable,
out: JointType[DynamicalSystem, BindCondData]):
super().__init__()
self.syn = syn
self.out = out

def update(self, *args, **kwargs):
self.out.bind_cond(self.syn(*args, **kwargs))


class _AlignPreMg(DynamicalSystem):
def __init__(self, access, syn):
super().__init__()
self.access = access
self.syn = syn

def update(self, *args, **kwargs):
return self.syn(self.access())


def _get_return(return_info):
if isinstance(return_info, bm.Variable):
return return_info.value
elif isinstance(return_info, ReturnInfo):
return return_info.get_data()
else:
raise NotImplementedError


class STDP_Song2000(Projection):
r"""Synaptic output with spike-time-dependent plasticity.
Expand Down Expand Up @@ -135,9 +92,10 @@ class STDP_Song2000(Projection):
A2: float, ArrayType, Callable. The increment of :math:`A_{post}` produced by a spike.
%s
"""

def __init__(
self,
pre: JointType[DynamicalSystem, AutoDelaySupp],
pre: JointType[DynamicalSystem, SupportAutoDelay],
delay: Union[None, int, float],
syn: ParamDescInit[DynamicalSystem],
comm: DynamicalSystem,
Expand All @@ -148,14 +106,15 @@ def __init__(
tau_t: Union[float, ArrayType, Callable] = 33.7,
A1: Union[float, ArrayType, Callable] = 0.96,
A2: Union[float, ArrayType, Callable] = 0.53,
# others
out_label: Optional[str] = None,
name: Optional[str] = None,
mode: Optional[bm.Mode] = None,
):
super().__init__(name=name, mode=mode)

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp])
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, ParamDescInit[DynamicalSystem])
check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]])
Expand Down Expand Up @@ -252,20 +211,24 @@ def calculate_trace(
return delay_cls.get_bef_update(_syn_id).syn

def update(self):
# pre spikes, and pre-synaptic variables
if issubclass(self.syn.cls, AlignPost):
pre_spike = self.refs['delay'].at(self.name)
x = pre_spike
else:
pre_spike = self.refs['delay'].access()
x = _get_return(self.refs['syn'].return_info())

# post spikes
post_spike = self.refs['post'].spike

# weight updates
Apre = self.refs['pre_trace'].g
Apost = self.refs['post_trace'].g
delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike)
self.comm.update_STDP(delta_w)

# currents
current = self.comm(x)
if issubclass(self.syn.cls, AlignPost):
self.refs['syn'].add_current(current) # synapse post current
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from brainpy import tools, math as bm
from brainpy._src.initialize import parameter, variable_
from brainpy._src.mixin import SupportAutoDelay, Container, ReceiveInputProj, DelayRegister, global_delay_data
from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.types import ArrayType, Shape
from brainpy._src.deprecations import _update_deprecate_msg
Expand Down Expand Up @@ -70,7 +70,7 @@ def update(self, x):
return func


class DynamicalSystem(bm.BrainPyObject, DelayRegister, ReceiveInputProj):
class DynamicalSystem(bm.BrainPyObject, DelayRegister, SupportInputProj):
"""Base Dynamical System class.

.. note::
Expand Down
2 changes: 2 additions & 0 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,8 @@ def fun(self):
# that has been created.
a = self.tracing_variable('a', bm.zeros, (10,))

.. versionadded:: 2.4.5

Args:
name: str. The variable name.
init: callable, Array. The data to be initialized as a ``Variable``.
Expand Down
13 changes: 12 additions & 1 deletion brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = [
'RandomState', 'Generator', 'DEFAULT',

'seed', 'default_rng', 'split_key',
'seed', 'default_rng', 'split_key', 'split_keys',

# numpy compatibility
'rand', 'randint', 'random_integers', 'randn', 'random',
Expand Down Expand Up @@ -1258,6 +1258,8 @@ def split_keys(n):
internally by `pmap` and `vmap` to ensure that random numbers
are different in parallel threads.

.. versionadded:: 2.4.5

Parameters
----------
n : int
Expand All @@ -1267,6 +1269,15 @@ def split_keys(n):


def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState:
"""Clone the random state according to the given setting.

Args:
seed_or_key: The seed (an integer) or the random key.
clone: Bool. Whether clone the default random state.

Returns:
The random state.
"""
if seed_or_key is None:
return DEFAULT.clone() if clone else DEFAULT
else:
Expand Down
138 changes: 74 additions & 64 deletions brainpy/_src/mixin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numbers
import sys
import warnings
from dataclasses import dataclass
from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any
from typing import (_SpecialForm, _type_check, _remove_dups_flatten)
Expand Down Expand Up @@ -28,12 +27,15 @@
'ParamDesc',
'ParamDescInit',
'AlignPost',
'SupportAutoDelay',
'Container',
'TreeNode',
'BindCondData',
'JointType',
'SupportSTDP',
'SupportAutoDelay',
'SupportInputProj',
'SupportOnline',
'SupportOffline',
]

global_delay_data = dict()
Expand All @@ -47,59 +49,6 @@ class MixIn(object):
pass


class ReceiveInputProj(MixIn):
"""The :py:class:`~.MixIn` that receives the input projections.

Note that the subclass should define a ``cur_inputs`` attribute.

"""
cur_inputs: bm.node_dict

def add_inp_fun(self, key: Any, fun: Callable):
"""Add an input function.

Args:
key: The dict key.
fun: The function to generate inputs.
"""
if not callable(fun):
raise TypeError('Must be a function.')
if key in self.cur_inputs:
raise ValueError(f'Key "{key}" has been defined and used.')
self.cur_inputs[key] = fun

def get_inp_fun(self, key):
"""Get the input function.

Args:
key: The key.

Returns:
The input function which generates currents.
"""
return self.cur_inputs.get(key)

def sum_inputs(self, *args, init=0., label=None, **kwargs):
"""Summarize all inputs by the defined input functions ``.cur_inputs``.

Args:
*args: The arguments for input functions.
init: The initial input data.
**kwargs: The arguments for input functions.

Returns:
The total currents.
"""
if label is None:
for key, out in self.cur_inputs.items():
init = init + out(*args, **kwargs)
else:
for key, out in self.cur_inputs.items():
if key.startswith(label + ' // '):
init = init + out(*args, **kwargs)
return init


class ParamDesc(MixIn):
""":py:class:`~.MixIn` indicates the function for describing initialization parameters.

Expand Down Expand Up @@ -208,13 +157,6 @@ def get_data(self):
return init


class SupportAutoDelay(MixIn):
"""``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

def return_info(self) -> Union[bm.Variable, ReturnInfo]:
raise NotImplementedError('Must implement the "return_info()" function.')


class Container(MixIn):
"""Container :py:class:`~.MixIn` which wrap a group of objects.
"""
Expand Down Expand Up @@ -550,8 +492,71 @@ def get_delay_var(self, name):
return global_delay_data[name]


class SupportInputProj(MixIn):
"""The :py:class:`~.MixIn` that receives the input projections.

Note that the subclass should define a ``cur_inputs`` attribute.

"""
cur_inputs: bm.node_dict

def add_inp_fun(self, key: Any, fun: Callable):
"""Add an input function.

Args:
key: The dict key.
fun: The function to generate inputs.
"""
if not callable(fun):
raise TypeError('Must be a function.')
if key in self.cur_inputs:
raise ValueError(f'Key "{key}" has been defined and used.')
self.cur_inputs[key] = fun

def get_inp_fun(self, key):
"""Get the input function.

Args:
key: The key.

Returns:
The input function which generates currents.
"""
return self.cur_inputs.get(key)

def sum_inputs(self, *args, init=0., label=None, **kwargs):
"""Summarize all inputs by the defined input functions ``.cur_inputs``.

Args:
*args: The arguments for input functions.
init: The initial input data.
**kwargs: The arguments for input functions.

Returns:
The total currents.
"""
if label is None:
for key, out in self.cur_inputs.items():
init = init + out(*args, **kwargs)
else:
for key, out in self.cur_inputs.items():
if key.startswith(label + ' // '):
init = init + out(*args, **kwargs)
return init


class SupportAutoDelay(MixIn):
"""``MixIn`` to support the automatic delay in synaptic projection :py:class:`~.SynProj`."""

def return_info(self) -> Union[bm.Variable, ReturnInfo]:
raise NotImplementedError('Must implement the "return_info()" function.')


class SupportOnline(MixIn):
""":py:class:`~.MixIn` to support the online training methods."""
""":py:class:`~.MixIn` to support the online training methods.

.. versionadded:: 2.4.5
"""

online_fit_by: Optional # methods for online fitting

Expand All @@ -563,7 +568,10 @@ def online_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):


class SupportOffline(MixIn):
""":py:class:`~.MixIn` to support the offline training methods."""
""":py:class:`~.MixIn` to support the offline training methods.

.. versionadded:: 2.4.5
"""

offline_fit_by: Optional # methods for offline fitting

Expand All @@ -573,6 +581,8 @@ def offline_fit(self, target: ArrayType, fit_record: Dict[str, ArrayType]):

class BindCondData(MixIn):
"""Bind temporary conductance data.


"""
_conductance: Optional

Expand Down
4 changes: 2 additions & 2 deletions brainpy/mixin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

from brainpy._src.mixin import (
MixIn as MixIn,
ReceiveInputProj as ReceiveInputProj,
SupportInputProj as SupportInputProj,
AlignPost as AlignPost,
AutoDelaySupp as AutoDelaySupp,
SupportAutoDelay as SupportAutoDelay,
ParamDesc as ParamDesc,
ParamDescInit as ParamDescInit,
BindCondData as BindCondData,
Expand Down
Loading