From f3d184320b0977a4249a0a3eec08ed6c96d36415 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 24 Oct 2023 19:06:00 +0800 Subject: [PATCH 1/5] update state resetting --- brainpy/_src/dynsys.py | 37 +++++++------------------------------ 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index e79f0a2df..e99610829 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import collections -import gc import inspect import warnings from typing import Union, Dict, Callable, Sequence, Optional, Any @@ -12,7 +11,7 @@ from brainpy._src.context import share from brainpy._src.deprecations import _update_deprecate_msg from brainpy._src.initialize import parameter, variable_ -from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, global_delay_data +from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError from brainpy.types import ArrayType, Shape @@ -27,6 +26,8 @@ 'Dynamic', 'Projection', ] + +IonChaDyn = None SLICE_VARS = 'slice_vars' @@ -163,7 +164,10 @@ def reset(self, *args, include_self: bool = False, **kwargs): include_self: bool. Reset states including the node self. Please turn on this if the node has implemented its ".reset_state()" function. """ - child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).unique() + global IonChaDyn + if IonChaDyn is None: + from brainpy._src.dyn.base import IonChaDyn + child_nodes = self.nodes(include_self=include_self).subset(DynamicalSystem).not_subset(IonChaDyn).unique() for node in child_nodes.values(): node.reset_state(*args, **kwargs) @@ -353,29 +357,6 @@ def __call__(self, *args, **kwargs): model(ret) return ret - def __del__(self): - """Function for handling `del` behavior. - - This function is used to pop out the variables which registered in global delay data. - """ - try: - if hasattr(self, 'local_delay_vars'): - for key in tuple(self.local_delay_vars.keys()): - val = global_delay_data.pop(key) - del val - val = self.local_delay_vars.pop(key) - del val - if hasattr(self, 'implicit_nodes'): - for key in tuple(self.implicit_nodes.keys()): - del self.implicit_nodes[key] - if hasattr(self, 'implicit_vars'): - for key in tuple(self.implicit_vars.keys()): - del self.implicit_vars[key] - for key in tuple(self.__dict__.keys()): - del self.__dict__[key] - finally: - gc.collect() - def __rrshift__(self, other): """Support using right shift operator to call modules. @@ -434,10 +415,6 @@ def update(self, *args, **kwargs): for node in nodes.not_subset(Dynamic).not_subset(Projection).values(): node() - # update delays - # TODO: Will be deprecated in the future - self.update_local_delays(nodes) - class Network(DynSysGroup): """A group of :py:class:`~.DynamicalSystem`s which defines the nodes and edges in a network. From f9514ad026d7870657a7e4012aeed4f39a427981 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 24 Oct 2023 19:30:43 +0800 Subject: [PATCH 2/5] [delay] integrate previous delay API into the new version of brainpy update --- brainpy/_src/delay.py | 10 +- brainpy/_src/dyn/synapses/delay_couplings.py | 12 +- .../_src/dynold/synapses/abstract_models.py | 2 +- brainpy/_src/mixin.py | 190 +++++------------- brainpy/mixin.py | 6 +- tests/simulation/test_net_rate_FHN.py | 1 + 6 files changed, 70 insertions(+), 151 deletions(-) diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index 0c0016155..d6cdfd682 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -369,7 +369,7 @@ def update( else: self.data[0] = latest_value - def reset_state(self, batch_size: int = None): + def reset_state(self, batch_size: int = None, **kwargs): """Reset the delay data. """ # initialize delay data @@ -439,7 +439,7 @@ def __init__( name=name, mode=mode) - def reset_state(self, batch_size: int = None): + def reset_state(self, batch_size: int = None, **kwargs): """Reset the delay data. """ self.target.value = variable_(self.target_init, self.target.size_without_batch, batch_size) @@ -476,9 +476,9 @@ def reset_state(self, *args, **kwargs): pass -def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay: +def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay: if isinstance(info, bm.Variable): - return VarDelay(info) + return VarDelay(info, init=initial_delay_data) elif isinstance(info, ReturnInfo): # batch size @@ -510,6 +510,6 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay: # variable target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names) - return DataDelay(target, data_init=info.data) + return DataDelay(target, data_init=info.data, init=initial_delay_data) else: raise TypeError diff --git a/brainpy/_src/dyn/synapses/delay_couplings.py b/brainpy/_src/dyn/synapses/delay_couplings.py index a4ecaa67c..ef43139da 100644 --- a/brainpy/_src/dyn/synapses/delay_couplings.py +++ b/brainpy/_src/dyn/synapses/delay_couplings.py @@ -191,7 +191,7 @@ def __init__( def update(self): # delays axis = self.coupling_var1.ndim - delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0] + delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') if self.delay_steps is None: diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) @@ -201,13 +201,13 @@ def update(self): indices = (slice(None, None, None), jnp.arange(self.coupling_var1.size),) else: indices = (jnp.arange(self.coupling_var1.size),) - f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num) + f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (..., pre.num) delays = f(self.delay_steps) # (..., post.num, pre.num) diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) elif self.delay_type == 'int': - delayed_data = delay_var(self.delay_steps) # (..., pre.num) + delayed_data = delay_var.retrieve(self.delay_steps) # (..., pre.num) diffusive = (jnp.expand_dims(delayed_data, axis=axis) - jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num) diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1) @@ -276,7 +276,7 @@ def __init__( def update(self): # delay function axis = self.coupling_var.ndim - delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0] + delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}') if self.delay_steps is None: additive = self.coupling_var @ self.conn_mat elif self.delay_type == 'array': @@ -284,11 +284,11 @@ def update(self): indices = (slice(None, None, None), jnp.arange(self.coupling_var.size),) else: indices = (jnp.arange(self.coupling_var.size),) - f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,) + f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (.., pre.num,) delays = f(self.delay_steps) # (..., post.num, pre.num) additive = (self.conn_mat * jnp.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1) elif self.delay_type == 'int': - delayed_var = delay_var(self.delay_steps) # (..., pre.num) + delayed_var = delay_var.retrieve(self.delay_steps) # (..., pre.num) additive = delayed_var @ self.conn_mat else: raise ValueError diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py index 60af8ee89..aef74a756 100644 --- a/brainpy/_src/dynold/synapses/abstract_models.py +++ b/brainpy/_src/dynold/synapses/abstract_models.py @@ -124,7 +124,7 @@ def reset_state(self, batch_size=None): def update(self, pre_spike=None): # pre-synaptic spikes if pre_spike is None: - pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step) + pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step) pre_spike = bm.as_jax(pre_spike) if self.stop_spike_gradient: pre_spike = jax.lax.stop_gradient(pre_spike) diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 37d3ca3b7..75249692a 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -1,16 +1,14 @@ import numbers import sys +import warnings from dataclasses import dataclass from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any from typing import (_SpecialForm, _type_check, _remove_dups_flatten) import jax -import jax.numpy as jnp -import numpy as np from brainpy import math as bm, tools from brainpy._src.math.object_transform.naming import get_unique_name -from brainpy._src.initialize import parameter from brainpy.types import ArrayType if sys.version_info.minor > 8: @@ -26,6 +24,7 @@ 'MixIn', 'ParamDesc', 'ParamDescInit', + 'DelayRegister', 'AlignPost', 'Container', 'TreeNode', @@ -38,7 +37,19 @@ 'SupportOffline', ] -global_delay_data = dict() + +def _get_delay_tool(): + global delay_identifier, init_delay_by_return + if init_delay_by_return is None: from brainpy._src.delay import init_delay_by_return + if delay_identifier is None: from brainpy._src.delay import delay_identifier + return delay_identifier, init_delay_by_return + + +def _get_dynsys(): + global DynamicalSystem + if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem + return DynamicalSystem + class MixIn(object): @@ -272,28 +283,28 @@ def check_hierarchy(self, root, leaf): class DelayRegister(MixIn): - local_delay_vars: bm.node_dict def register_delay_at( self, name: str, delay: Union[numbers.Number, ArrayType] = None, + target: Optional[bm.Variable] = None ): """Register relay at the given delay time. Args: name: str. The identifier of the delay. delay: The delay time. + target: Variable. The delay target variable. """ - global delay_identifier, init_delay_by_return, DynamicalSystem - if init_delay_by_return is None: from brainpy._src.delay import init_delay_by_return - if delay_identifier is None: from brainpy._src.delay import delay_identifier - if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem - - assert isinstance(self, SupportAutoDelay), f'self must be an instance of {SupportAutoDelay.__name__}' + delay_identifier, init_delay_by_return = _get_delay_tool() + DynamicalSystem = _get_dynsys() assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' if not self.has_aft_update(delay_identifier): - self.add_aft_update(delay_identifier, init_delay_by_return(self.return_info())) + if target is None: + assert isinstance(self, SupportAutoDelay), f'self must be an instance of {SupportAutoDelay.__name__}' + target = self.return_info() + self.add_aft_update(delay_identifier, init_delay_by_return(target)) delay_cls = self.get_aft_update(delay_identifier) delay_cls.register_entry(name, delay) @@ -324,71 +335,23 @@ def register_delay( initial_delay_data: The initializer for the delay data. Returns: - delay_step: The number of the delay steps. + delay_pos: The position of the delay. """ - # warnings.warn('\n' - # 'Starting from brainpy>=2.4.4, instead of ".register_delay()", ' - # 'we recommend the user to first use ".register_delay_at()", ' - # 'then use ".get_delay_at()" to access the delayed data. ' - # '".register_delay()" will be removed after 2.5.0.', - # UserWarning) - - # delay steps - if delay_step is None: - delay_type = 'none' - elif isinstance(delay_step, (int, np.integer, jnp.integer)): - delay_type = 'homo' - elif isinstance(delay_step, (bm.ndarray, jnp.ndarray, np.ndarray)): - if delay_step.size == 1 and delay_step.ndim == 0: - delay_type = 'homo' - else: - delay_type = 'heter' - delay_step = bm.asarray(delay_step) - elif callable(delay_step): - delay_step = parameter(delay_step, delay_target.shape, allow_none=False) - delay_type = 'heter' - else: - raise ValueError(f'Unknown "delay_steps" type {type(delay_step)}, only support ' - f'integer, array of integers, callable function, brainpy.init.Initializer.') - if delay_type == 'heter': - if delay_step.dtype not in [bm.int32, bm.int64]: - raise ValueError('Only support delay steps of int32, int64. If your ' - 'provide delay time length, please divide the "dt" ' - 'then provide us the number of delay steps.') - if delay_target.shape[0] != delay_step.shape[0]: - raise ValueError(f'Shape is mismatched: {delay_target.shape[0]} != {delay_step.shape[0]}') - if delay_type != 'none': - max_delay_step = int(bm.max(delay_step)) - - # delay target - if delay_type != 'none': - if not isinstance(delay_target, bm.Variable): - raise ValueError(f'"delay_target" must be an instance of Variable, but we got {type(delay_target)}') - - # delay variable - # TODO - if delay_type != 'none': - if identifier not in global_delay_data: - delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data) - global_delay_data[identifier] = (delay, delay_target) - self.local_delay_vars[identifier] = delay - else: - delay = global_delay_data[identifier][0] - if delay is None: - delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data) - global_delay_data[identifier] = (delay, delay_target) - self.local_delay_vars[identifier] = delay - elif delay.num_delay_step - 1 < max_delay_step: - global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data) - else: - if identifier not in global_delay_data: - global_delay_data[identifier] = (None, delay_target) - return delay_step + _delay_identifier, _init_delay_by_return = _get_delay_tool() + DynamicalSystem = _get_dynsys() + assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' + _delay_identifier = _delay_identifier + identifier + if not self.has_aft_update(_delay_identifier): + self.add_aft_update(_delay_identifier, _init_delay_by_return(delay_target, initial_delay_data)) + delay_cls = self.get_aft_update(_delay_identifier) + name = get_unique_name('delay') + delay_cls.register_entry(name, delay_step) + return name def get_delay_data( self, identifier: str, - delay_step: Optional[Union[int, bm.Array, jax.Array]], + delay_pos: str, *indices: Union[int, slice, bm.Array, jax.Array], ): """Get delay data according to the provided delay steps. @@ -397,7 +360,7 @@ def get_delay_data( ---------- identifier: str The delay variable name. - delay_step: Optional, int, ArrayType + delay_pos: str The delay length. indices: optional, int, slice, ArrayType The indices of the delay. @@ -407,34 +370,10 @@ def get_delay_data( delay_data: ArrayType The delay data at the given time. """ - # warnings.warn('\n' - # 'Starting from brainpy>=2.4.4, instead of ".get_delay_data()", ' - # 'we recommend the user to first use ".register_delay_at()", ' - # 'then use ".get_delay_at()" to access the delayed data.' - # '".get_delay_data()" will be removed after 2.5.0.', - # UserWarning) - - if delay_step is None: - return global_delay_data[identifier][1].value - - if identifier in global_delay_data: - if bm.ndim(delay_step) == 0: - return global_delay_data[identifier][0](delay_step, *indices) - else: - if len(indices) == 0: - indices = (bm.arange(delay_step.size),) - return global_delay_data[identifier][0](delay_step, *indices) - - elif identifier in self.local_delay_vars: - if bm.ndim(delay_step) == 0: - return self.local_delay_vars[identifier](delay_step) - else: - if len(indices) == 0: - indices = (bm.arange(delay_step.size),) - return self.local_delay_vars[identifier](delay_step, *indices) - - else: - raise ValueError(f'{identifier} is not defined in delay variables.') + _delay_identifier, _init_delay_by_return = _get_delay_tool() + _delay_identifier = _delay_identifier + identifier + delay_cls = self.get_aft_update(_delay_identifier) + return delay_cls.at(delay_pos, *indices) def update_local_delays(self, nodes: Union[Sequence, Dict] = None): """Update local delay variables. @@ -448,22 +387,8 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): nodes: sequence, dict The nodes to update their delay variables. """ - global DynamicalSystem - if DynamicalSystem is None: - from brainpy._src.dynsys import DynamicalSystem - - # update delays - if nodes is None: - nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()) - elif isinstance(nodes, dict): - nodes = tuple(nodes.values()) - if not isinstance(nodes, (tuple, list)): - nodes = (nodes,) - for node in nodes: - for name in node.local_delay_vars: - delay = global_delay_data[name][0] - target = global_delay_data[name][1] - delay.update(target.value) + warnings.warn('.update_local_delays() has been removed since brainpy>=2.4.6', + DeprecationWarning) def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): """Reset local delay variables. @@ -473,23 +398,14 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): nodes: sequence, dict The nodes to Reset their delay variables. """ - global DynamicalSystem - if DynamicalSystem is None: - from brainpy._src.dynsys import DynamicalSystem - - # reset delays - if nodes is None: - nodes = self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values() - elif isinstance(nodes, dict): - nodes = nodes.values() - for node in nodes: - for name in node.local_delay_vars: - delay = global_delay_data[name][0] - target = global_delay_data[name][1] - delay.reset(target.value) + warnings.warn('.reset_local_delays() has been removed since brainpy>=2.4.6', + DeprecationWarning) def get_delay_var(self, name): - return global_delay_data[name] + _delay_identifier, _init_delay_by_return = _get_delay_tool() + _delay_identifier = _delay_identifier + name + delay_cls = self.get_aft_update(_delay_identifier) + return delay_cls class SupportInputProj(MixIn): @@ -599,10 +515,12 @@ def unbind_cond(self): class SupportSTDP(MixIn): """Support synaptic plasticity by modifying the weights. """ - def update_STDP(self, - dW: Union[bm.Array, jax.Array], - constraints: Optional[Callable] = None, - ): + + def update_STDP( + self, + dW: Union[bm.Array, jax.Array], + constraints: Optional[Callable] = None, + ): raise NotImplementedError diff --git a/brainpy/mixin.py b/brainpy/mixin.py index ab3c3cd37..232fd744e 100644 --- a/brainpy/mixin.py +++ b/brainpy/mixin.py @@ -1,14 +1,14 @@ from brainpy._src.mixin import ( MixIn as MixIn, - SupportInputProj as SupportInputProj, AlignPost as AlignPost, - SupportAutoDelay as SupportAutoDelay, ParamDesc as ParamDesc, ParamDescInit as ParamDescInit, BindCondData as BindCondData, Container as Container, TreeNode as TreeNode, JointType as JointType, - SupportSTDP as SupportPlasticity, + SupportAutoDelay as SupportAutoDelay, + SupportInputProj as SupportInputProj, + SupportSTDP as SupportSTDP, ) diff --git a/tests/simulation/test_net_rate_FHN.py b/tests/simulation/test_net_rate_FHN.py index 157eeb78a..de90794d7 100644 --- a/tests/simulation/test_net_rate_FHN.py +++ b/tests/simulation/test_net_rate_FHN.py @@ -9,6 +9,7 @@ show = False bm.set_platform('cpu') + class Network(bp.Network): def __init__(self, signal_speed=20.): super(Network, self).__init__() From cf4267d92ee76c6b454d8d4dfe0eb980fdd31159 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 24 Oct 2023 19:59:10 +0800 Subject: [PATCH 3/5] [delay] generalize delay register 1. in the brainpy 2.4.x version, the delay is registered with ``return_info()`` function. From now on, the delay can be registered as ``.register_local_delay('spike', 'pre', 1.)``. 2. this generalizes the delay registration APIs so it can be used to register of any variables --- brainpy/_src/dynsys.py | 71 ++++++++++++++++++++------------ brainpy/_src/mixin.py | 37 +---------------- brainpy/_src/tests/test_mixin.py | 4 +- 3 files changed, 48 insertions(+), 64 deletions(-) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index e99610829..d85c16d9c 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -3,6 +3,7 @@ import collections import inspect import warnings +import numbers from typing import Union, Dict, Callable, Sequence, Optional, Any import numpy as np @@ -11,7 +12,7 @@ from brainpy._src.context import share from brainpy._src.deprecations import _update_deprecate_msg from brainpy._src.initialize import parameter, variable_ -from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister +from brainpy._src.mixin import SupportAutoDelay, Container, SupportInputProj, DelayRegister, _get_delay_tool from brainpy.errors import NoImplementationError, UnsupportedError, APIChangedError from brainpy.types import ArrayType, Shape @@ -88,14 +89,10 @@ def __init__( f'which are parents of {self.supported_modes}, ' f'but we got {self.mode}.') - # Attribute for "ReceiveInputProj" + # Attribute for "SupportInputProj" + # each instance of "SupportInputProj" should have a "cur_inputs" attribute self.cur_inputs = bm.node_dict() - # local delay variables: - # Compatible for ``DelayRegister`` - # TODO: will be deprecated in the future - self.local_delay_vars: Dict = bm.node_dict() - # the before- / after-updates used for computing # added after the version of 2.4.3 self.before_updates: Dict[str, Callable] = bm.node_dict() @@ -136,18 +133,6 @@ def has_aft_update(self, key: Any): """Whether this node has the after update of the given ``key``.""" return key in self.after_updates - def reset_bef_updates(self, *args, **kwargs): - """Reset all before updates.""" - for node in self.before_updates.values(): - if isinstance(node, DynamicalSystem): - node.reset(*args, **kwargs) - - def reset_aft_updates(self, *args, **kwargs): - """Reset all after updates.""" - for node in self.after_updates.values(): - if isinstance(node, DynamicalSystem): - node.reset(*args, **kwargs) - def update(self, *args, **kwargs): """The function to specify the updating rule. """ @@ -240,6 +225,44 @@ def mode(self, value): f'but we got {type(value)}: {value}') self._mode = value + def register_local_delay( + self, + var_name: str, + delay_name: str, + delay: Union[numbers.Number, ArrayType] = None, + ): + """Register local relay at the given delay time. + + Args: + var_name: str. The name of the delay target variable. + delay_name: str. The name of the current delay data. + delay: The delay time. + """ + delay_identifier, init_delay_by_return = _get_delay_tool() + delay_identifier = delay_identifier + var_name + try: + target = getattr(self, var_name) + except AttributeError: + raise AttributeError(f'This node {self} does not has attribute of "{var_name}".') + if not self.has_aft_update(delay_identifier): + self.add_aft_update(delay_identifier, init_delay_by_return(target)) + delay_cls = self.get_aft_update(delay_identifier) + delay_cls.register_entry(delay_name, delay) + + def get_local_delay(self, var_name, delay_name): + """Get the delay at the given identifier (`name`). + + Args: + var_name: The name of the target delay variable. + delay_name: The identifier of the delay. + + Returns: + The delayed data at the given delay position. + """ + delay_identifier, init_delay_by_return = _get_delay_tool() + delay_identifier = delay_identifier + var_name + return self.get_aft_update(delay_identifier).at(delay_name) + def _compatible_update(self, *args, **kwargs): update_fun = super().__getattribute__('update') update_args = tuple(inspect.signature(update_fun).parameters.values()) @@ -324,11 +347,8 @@ def _compatible_update(self, *args, **kwargs): return ret return update_fun(*args, **kwargs) - # def __getattr__(self, item): - # if item == 'update': - # return self._compatible_update # update function compatible with previous ``update()`` function - # else: - # return object.__getattribute__(self, item) + def _get_update_fun(self): + return object.__getattribute__(self, 'update') def __getattribute__(self, item): if item == 'update': @@ -336,9 +356,6 @@ def __getattribute__(self, item): else: return super().__getattribute__(item) - def _get_update_fun(self): - return object.__getattribute__(self, 'update') - def __repr__(self): return f'{self.name}(mode={self.mode})' diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 75249692a..39c3ace6b 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -1,3 +1,5 @@ +# -*- coding: utf-8 -*- + import numbers import sys import warnings @@ -284,41 +286,6 @@ def check_hierarchy(self, root, leaf): class DelayRegister(MixIn): - def register_delay_at( - self, - name: str, - delay: Union[numbers.Number, ArrayType] = None, - target: Optional[bm.Variable] = None - ): - """Register relay at the given delay time. - - Args: - name: str. The identifier of the delay. - delay: The delay time. - target: Variable. The delay target variable. - """ - delay_identifier, init_delay_by_return = _get_delay_tool() - DynamicalSystem = _get_dynsys() - assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' - if not self.has_aft_update(delay_identifier): - if target is None: - assert isinstance(self, SupportAutoDelay), f'self must be an instance of {SupportAutoDelay.__name__}' - target = self.return_info() - self.add_aft_update(delay_identifier, init_delay_by_return(target)) - delay_cls = self.get_aft_update(delay_identifier) - delay_cls.register_entry(name, delay) - - def get_delay_at(self, name): - """Get the delay at the given identifier (`name`). - - Args: - name: The identifier of the delay. - - Returns: - The delay data. - """ - return self.get_aft_update(delay_identifier).at(name) - def register_delay( self, identifier: str, diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index 8a1aece7c..be8eaade6 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -42,8 +42,8 @@ class TestDelayRegister(unittest.TestCase): def test2(self): bp.share.save(i=0) lif = bp.dyn.Lif(10) - lif.register_delay_at('a', 10.) - data = lif.get_delay_at('a') + lif.register_local_delay('a', 10.) + data = lif.get_local_delay('a') self.assertTrue(bm.allclose(data, bm.zeros(10))) From 5b97741f13d94525ac18e8cd78de8975ccefcbf9 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 24 Oct 2023 20:08:56 +0800 Subject: [PATCH 4/5] update reset_states --- brainpy/_src/dyn/projections/aligns.py | 38 +++++++++++++++++++++- brainpy/_src/dyn/projections/others.py | 3 ++ brainpy/_src/dyn/projections/plasticity.py | 3 ++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index 23b907286..c19f45844 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -27,6 +27,9 @@ def update(self, x): else: return x >> self.syn >> self.delay + def reset_state(self, *args, **kwargs): + pass + class _AlignPost(DynamicalSystem): def __init__(self, @@ -39,6 +42,9 @@ def __init__(self, def update(self, *args, **kwargs): self.out.bind_cond(self.syn(*args, **kwargs)) + def reset_state(self, *args, **kwargs): + pass + class _AlignPreMg(DynamicalSystem): def __init__(self, access, syn): @@ -49,6 +55,9 @@ def __init__(self, access, syn): def update(self, *args, **kwargs): return self.syn(self.access()) + def reset_state(self, *args, **kwargs): + pass + def _get_return(return_info): if isinstance(return_info, bm.Variable): @@ -132,6 +141,9 @@ def update(self, x): self.refs['out'].bind_cond(current) return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPostMg1(Projection): r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -224,6 +236,9 @@ def update(self, x): self.refs['syn'].add_current(current) # synapse post current return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPostMg2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -352,6 +367,9 @@ def update(self): self.refs['syn'].add_current(current) # synapse post current return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPost1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -438,6 +456,9 @@ def update(self, x): self.refs['syn'].add_current(current) return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPost2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group. @@ -561,6 +582,9 @@ def update(self): self.refs['out'].bind_cond(g) # synapse post current return g + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPreMg1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -686,6 +710,9 @@ def update(self, x=None): self.refs['out'].bind_cond(current) return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPreMg2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -814,6 +841,9 @@ def update(self): self.refs['out'].bind_cond(current) return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPre1(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -933,6 +963,9 @@ def update(self, x=None): self.refs['out'].bind_cond(current) return current + def reset_state(self, *args, **kwargs): + pass + class ProjAlignPre2(Projection): """Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group. @@ -1052,4 +1085,7 @@ def update(self): spk = self.refs['delay'].at(self.name) g = self.comm(self.syn(spk)) self.refs['out'].bind_cond(g) - return g \ No newline at end of file + return g + + def reset_state(self, *args, **kwargs): + pass diff --git a/brainpy/_src/dyn/projections/others.py b/brainpy/_src/dyn/projections/others.py index 44cdfb043..72a77298f 100644 --- a/brainpy/_src/dyn/projections/others.py +++ b/brainpy/_src/dyn/projections/others.py @@ -54,6 +54,9 @@ def __init__( self.freq = check.is_float(freq, min_bound=0., allow_int=True) self.weight = check.is_float(weight, allow_int=True) + def reset_state(self, *args, **kwargs): + pass + def update(self): p = self.freq * share['dt'] / 1e3 a = self.num_input * p diff --git a/brainpy/_src/dyn/projections/plasticity.py b/brainpy/_src/dyn/projections/plasticity.py index 01f3e7bea..7c176c125 100644 --- a/brainpy/_src/dyn/projections/plasticity.py +++ b/brainpy/_src/dyn/projections/plasticity.py @@ -174,6 +174,9 @@ def __init__( self.A1 = parameter(A1, sizes=self.pre_num) self.A2 = parameter(A2, sizes=self.post_num) + def reset_state(self, *args, **kwargs): + pass + def _init_trace( self, target: DynamicalSystem, From 57b25f602466ccae94f6323a057848b69d5b844f Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 24 Oct 2023 20:32:41 +0800 Subject: [PATCH 5/5] fix tests --- brainpy/_src/tests/test_mixin.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index be8eaade6..5fbab7b9f 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -42,9 +42,12 @@ class TestDelayRegister(unittest.TestCase): def test2(self): bp.share.save(i=0) lif = bp.dyn.Lif(10) - lif.register_local_delay('a', 10.) - data = lif.get_local_delay('a') + lif.register_local_delay('spike', 'a', 10.) + data = lif.get_local_delay('spike', 'a') self.assertTrue(bm.allclose(data, bm.zeros(10))) + with self.assertRaises(AttributeError): + lif.register_local_delay('a', 'a', 10.) +