Skip to content

Commit

Permalink
fix DynamicalSystem.register_local_delay() bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 5, 2024
1 parent 226b7f2 commit e92fcac
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 16 deletions.
36 changes: 27 additions & 9 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
delay_identifier = '_*_delay_of_'


def _get_delay(delay_time, delay_step):
if delay_time is None:
if delay_step is None:
return None, None
else:
assert isinstance(delay_step, int), '"delay_step" should be an integer.'
delay_time = delay_step * bm.get_dt()
else:
assert delay_step is None, '"delay_step" should be None if "delay_time" is given.'
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
return delay_time, delay_step


class Delay(DynamicalSystem, ParamDesc):
"""Base class for delay variables.
Expand Down Expand Up @@ -97,13 +111,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[float, bm.Array, Callable]],
delay_time: Optional[Union[float, bm.Array, Callable]] = None,
delay_step: Optional[int] = None
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delay_step = delay_time / dt``.
Returns:
Return the self.
Expand Down Expand Up @@ -237,13 +253,15 @@ def __init__(
def register_entry(
self,
entry: str,
delay_time: Optional[Union[int, float]],
delay_time: Optional[Union[int, float]] = None,
delay_step: Optional[int] = None,
) -> 'Delay':
"""Register an entry to access the data.
Args:
entry: str. The entry to access the delay data.
delay_time: The delay time of the entry (can be a float).
delay_step: The delay step of the entry (must be an int). ``delat_step = delay_time / dt``.
Returns:
Return the self.
Expand All @@ -258,12 +276,7 @@ def register_entry(
assert delay_time.size == 1 and delay_time.ndim == 0
delay_time = delay_time.item()

if delay_time is None:
delay_step = None
delay_time = 0.
else:
assert isinstance(delay_time, (int, float))
delay_step = math.ceil(delay_time / bm.get_dt())
_, delay_step = _get_delay(delay_time, delay_step)

# delay variable
if delay_step is not None:
Expand Down Expand Up @@ -354,14 +367,16 @@ def update(
"""Update delay variable with the new data.
"""
if self.data is not None:
# jax.debug.print('last value == target value {} ', jnp.allclose(latest_value, self.target.value))

# get the latest target value
if latest_value is None:
latest_value = self.target.value

# update the delay data at the rotation index
if self.method == ROTATE_UPDATE:
i = share.load('i')
idx = bm.as_jax((-i ) % self.max_length, dtype=jnp.int32)
idx = bm.as_jax(-i % self.max_length, dtype=jnp.int32)
self.data[jax.lax.stop_gradient(idx)] = latest_value

# update the delay data at the first position
Expand All @@ -372,6 +387,9 @@ def update(
else:
self.data[0] = latest_value

else:
raise ValueError(f'Unknown updating method "{self.method}"')

def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def __init__(
self.g_max, self.conn_mask = self._init_weights(g_max, comp_method=comp_method, sparse_data='csr')

# register delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

def update(self, pre_spike=None):
# pre-synaptic spikes
Expand Down Expand Up @@ -278,7 +278,7 @@ def __init__(
raise ValueError(f'Does not support {comp_method}, only "sparse" or "dense".')

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

@property
def g(self):
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dynold/synapses/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def __init__(
mode=mode)

# delay
self.pre.register_local_delay("spike", self.name, delay_step)
self.pre.register_local_delay("spike", self.name, delay_step=delay_step)

# synaptic dynamics
self.syn = syn
Expand Down
8 changes: 5 additions & 3 deletions brainpy/_src/dynsys.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,14 +244,16 @@ def register_local_delay(
self,
var_name: str,
delay_name: str,
delay: Union[numbers.Number, ArrayType] = None,
delay_time: Union[numbers.Number, ArrayType] = None,
delay_step: Union[numbers.Number, ArrayType] = None,
):
"""Register local relay at the given delay time.
Args:
var_name: str. The name of the delay target variable.
delay_name: str. The name of the current delay data.
delay: The delay time.
delay_time: The delay time. Float.
delay_step: The delay step. Int. ``delay_step`` and ``delay_time`` are exclusive. ``delay_step = delay_time / dt``.
"""
delay_identifier, init_delay_by_return = _get_delay_tool()
delay_identifier = delay_identifier + var_name
Expand All @@ -262,7 +264,7 @@ def register_local_delay(
if not self.has_aft_update(delay_identifier):
self.add_aft_update(delay_identifier, init_delay_by_return(target))
delay_cls = self.get_aft_update(delay_identifier)
delay_cls.register_entry(delay_name, delay)
delay_cls.register_entry(delay_name, delay_time=delay_time, delay_step=delay_step)

def get_local_delay(self, var_name, delay_name):
"""Get the delay at the given identifier (`name`).
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/tests/test_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TestDelayRegister(unittest.TestCase):
def test2(self):
bp.share.save(i=0)
lif = bp.dyn.Lif(10)
lif.register_local_delay('spike', 'a', 10.)
lif.register_local_delay('spike', 'a', delay_time=10.)
data = lif.get_local_delay('spike', 'a')
self.assertTrue(bm.allclose(data, bm.zeros(10)))

Expand Down

0 comments on commit e92fcac

Please sign in to comment.