Skip to content

Commit

Permalink
Fix delay bugs (#650)
Browse files Browse the repository at this point in the history
* remove `reset_state()`

* make default attribute None

* fix delay bugs

* fix `DynamicalSystem.register_local_delay()` bug

* add the following decorators for enhancing the ``update()`` capability:

 - `brainpy.receive_update_output()`
 - `brainpy.receive_update_input()`
 - `brainpy.not_receive_update_output()`
 - `brainpy.not_receive_update_input()`

* fix
  • Loading branch information
chaoming0625 authored Mar 5, 2024
1 parent 6c36794 commit 7511afd
Show file tree
Hide file tree
Showing 9 changed files with 275 additions and 63 deletions.
5 changes: 4 additions & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@
Sequential as Sequential,
Dynamic as Dynamic, # category
Projection as Projection,
receive_update_input, # decorators
receive_update_output,
not_receive_update_input,
not_receive_update_output,
)
DynamicalSystemNS = DynamicalSystem
Network = DynSysGroup
Expand All @@ -84,7 +88,6 @@
load_state as load_state,
clear_input as clear_input)


# Part: Running #
# --------------- #
from brainpy._src.runners import (DSRunner as DSRunner)
Expand Down
42 changes: 30 additions & 12 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,21 @@
]


delay_identifier = '_*_delay_*_'
delay_identifier = '_*_delay_of_'


def _get_delay(delay_time, delay_step):
if delay_time is None:
if delay_step is None:
return None, None
else:
assert isinstance(delay_step, int), '"delay_step" should be an integer.'
delay_time = delay_step * bm.get_dt()
else:
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
return delay_time, delay_step


class Delay(DynamicalSystem, ParamDesc):
Expand Down Expand Up @@ -97,13 +111,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[float, bm.Array, Callable]],
delay_time: Optional[Union[float, bm.Array, Callable]] = None,
delay_step: Optional[int] = None
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.
Returns:
Return the self.
Expand Down Expand Up @@ -237,13 +253,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[int, float]],
delay_time: Optional[Union[int, float]] = None,
delay_step: Optional[int] = None,
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.
Returns:
Return the self.
Expand All @@ -258,12 +276,7 @@ def register_entry(
assert delay_time.size == 1 and delay_time.ndim == 0
delay_time = delay_time.item()

if delay_time is None:
delay_step = None
delay_time = 0.
else:
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
_, delay_step = _get_delay(delay_time, delay_step)

# delay variable
if delay_step is not None:
Expand Down Expand Up @@ -354,24 +367,29 @@ def update(
"""Update delay variable with the new data.
"""
if self.data is not None:
# jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value))

# get the latest target value
if latest_value is None:
latest_value = self.target.value

# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
i = share.load('i')
idx = bm.as_jax((-i - 1) % self.max_length, dtype=jnp.int32)
self.data[idx] = latest_value
idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32)
self.data[jax.lax.stop_gradient(idx)] = latest_value

# update the delay data at the first position
elif self.method == CONCAT_UPDATE:
if self.max_length > 1:
latest_value = bm.expand_dims(latest_value, 0)
self.data.value = bm.concat([latest_value, self.data[1:]], axis=0)
self.data.value = bm.concat([latest_value, self.data[:-1]], axis=0)
else:
self.data[0] = latest_value

else:
raise ValueError(f'Unknown updating method "{self.method}"')

def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
Expand Down
25 changes: 9 additions & 16 deletions brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.pre.register_local_delay("spike", self.name, delay_step)

def reset_state(self, batch_size=None):
self.output.reset_state(batch_size)
if self.stp is not None:
self.stp.reset_state(batch_size)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

def update(self, pre_spike=None):
# pre-synaptic spikes
Expand Down Expand Up @@ -232,7 +227,6 @@ class Exponential(TwoEndConn):
method: str
The numerical integration methods.
"""

def __init__(
Expand Down Expand Up @@ -283,17 +277,16 @@ def __init__(
else:
raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')

# variables
self.g = self.syn.g

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

def reset_state(self, batch_size=None):
self.syn.reset_state(batch_size)
self.output.reset_state(batch_size)
if self.stp is not None:
self.stp.reset_state(batch_size)
@property
def g(self):
return self.syn.g

@g.setter
def g(self, value):
self.syn.g = value

def update(self, pre_spike=None):
# delays
Expand Down
14 changes: 2 additions & 12 deletions brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from brainpy._src.dyn.base import NeuDyn
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (ParamDesc, JointType,
SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy._src.mixin import (ParamDesc, JointType, SupportAutoDelay, BindCondData, ReturnInfo)
from brainpy.errors import UnsupportedError
from brainpy.types import ArrayType

Expand Down Expand Up @@ -47,9 +46,6 @@ def isregistered(self, val: bool):
raise ValueError('Must be an instance of bool.')
self._registered = val

def reset_state(self, batch_size=None):
pass

def register_master(self, master: SynConn):
if not isinstance(master, SynConn):
raise TypeError(f'master must be instance of {SynConn.__name__}, but we got {type(master)}')
Expand Down Expand Up @@ -296,7 +292,7 @@ def __init__(
mode=mode)

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

# synaptic dynamics
self.syn = syn
Expand Down Expand Up @@ -340,11 +336,5 @@ def g_max(self, v):
UserWarning)
self.comm.weight = v

def reset_state(self, *args, **kwargs):
self.syn.reset(*args, **kwargs)
self.comm.reset(*args, **kwargs)
self.output.reset(*args, **kwargs)
if self.stp is not None:
self.stp.reset(*args, **kwargs)


129 changes: 119 additions & 10 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,41 @@ def __init__(

# Attribute for "SupportInputProj"
# each instance of "SupportInputProj" should have a "cur_inputs" attribute
self.current_inputs = bm.node_dict()
self.delta_inputs = bm.node_dict()
self._current_inputs: Optional[Dict[str, Callable]] = None
self._delta_inputs: Optional[Dict[str, Callable]] = None

# the before- / after-updates used for computing
# added after the version of 2.4.3
self.before_updates: Dict[str, Callable] = bm.node_dict()
self.after_updates: Dict[str, Callable] = bm.node_dict()
self._before_updates: Optional[Dict[str, Callable]] = None
self._after_updates: Optional[Dict[str, Callable]] = None

# super initialization
super().__init__(name=name)

@property
def current_inputs(self):
if self._current_inputs is None:
self._current_inputs = bm.node_dict()
return self._current_inputs

@property
def delta_inputs(self):
if self._delta_inputs is None:
self._delta_inputs = bm.node_dict()
return self._delta_inputs

@property
def before_updates(self):
if self._before_updates is None:
self._before_updates = bm.node_dict()
return self._before_updates

@property
def after_updates(self):
if self._after_updates is None:
self._after_updates = bm.node_dict()
return self._after_updates

def add_bef_update(self, key: Any, fun: Callable):
"""Add the before update into this node"""
if key in self.before_updates:
Expand Down Expand Up @@ -220,25 +244,32 @@ def register_local_delay(
self,
var_name: str,
delay_name: str,
delay: Union[numbers.Number, ArrayType] = None,
delay_time: Union[numbers.Number, ArrayType] = None,
delay_step: Union[numbers.Number, ArrayType] = None,
):
"""Register local relay at the given delay time.
Args:
var_name: str. The name of the delay target variable.
delay_name: str. The name of the current delay data.
delay: The delay time.
delay_time: The delay time. Float.
delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``.
"""
delay_identifier, init_delay_by_return = _get_delay_tool()
delay_identifier = delay_identifier + var_name
# check whether the "var_name" has been registered
try:
target = getattr(self, var_name)
except AttributeError:
raise AttributeError(f'This node {self} does not has attribute of "{var_name}".')
if not self.has_aft_update(delay_identifier):
self.add_aft_update(delay_identifier, init_delay_by_return(target))
# add a model to receive the return of the target model
# moreover, the model should not receive the return of the update function
model = not_receive_update_output(init_delay_by_return(target))
# register the model
self.add_aft_update(delay_identifier, model)
delay_cls = self.get_aft_update(delay_identifier)
delay_cls.register_entry(delay_name, delay)
delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)

def get_local_delay(self, var_name, delay_name):
"""Get the delay at the given identifier (`name`).
Expand Down Expand Up @@ -381,14 +412,20 @@ def __call__(self, *args, **kwargs):

# ``before_updates``
for model in self.before_updates.values():
model()
if hasattr(model, '_receive_update_input'):
model(*args, **kwargs)
else:
model()

# update the model self
ret = self.update(*args, **kwargs)

# ``after_updates``
for model in self.after_updates.values():
model(ret)
if hasattr(model, '_not_receive_update_output'):
model()
else:
model(ret)
return ret

def __rrshift__(self, other):
Expand Down Expand Up @@ -832,3 +869,75 @@ def _slice_to_num(slice_: slice, length: int):
start += step
num += 1
return num


def receive_update_output(cls: object):
"""
The decorator to mark the object (as the after updates) to receive the output of the update function.
That is, the `aft_update` will receive the return of the update function::
ret = model.update(*args, **kwargs)
for fun in model.aft_updates:
fun(ret)
"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
if hasattr(cls, '_not_receive_update_output'):
delattr(cls, '_not_receive_update_output')
return cls


def not_receive_update_output(cls: object):
"""
The decorator to mark the object (as the after updates) to not receive the output of the update function.
That is, the `aft_update` will not receive the return of the update function::
ret = model.update(*args, **kwargs)
for fun in model.aft_updates:
fun()
"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
cls._not_receive_update_output = True
return cls


def receive_update_input(cls: object):
"""
The decorator to mark the object (as the before updates) to receive the input of the update function.
That is, the `bef_update` will receive the input of the update function::
for fun in model.bef_updates:
fun(*args, **kwargs)
model.update(*args, **kwargs)
"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
cls._receive_update_input = True
return cls


def not_receive_update_input(cls: object):
"""
The decorator to mark the object (as the before updates) to not receive the input of the update function.
That is, the `bef_update` will not receive the input of the update function::
for fun in model.bef_updates:
fun()
model.update()
"""
# assert isinstance(cls, DynamicalSystem), 'The input class should be instance of DynamicalSystem.'
if hasattr(cls, '_receive_update_input'):
delattr(cls, '_receive_update_input')
return cls





Loading

0 comments on commit 7511afd

Please sign in to comment.