Skip to content

Commit

Permalink
Merge pull request #522 from chaoming0625/updates
Browse files Browse the repository at this point in the history
[delay] rewrite previous delay APIs so that they are compatible with new brainpy version
  • Loading branch information
chaoming0625 authored Oct 24, 2023
2 parents aa90559 + 57b25f6 commit e849a9a
Show file tree
Hide file tree
Showing 11 changed files with 162 additions and 237 deletions.
10 changes: 5 additions & 5 deletions brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def update(
else:
self.data[0] = latest_value

def reset_state(self, batch_size: int = None):
def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
# initialize delay data
Expand Down Expand Up @@ -439,7 +439,7 @@ def __init__(
name=name,
mode=mode)

def reset_state(self, batch_size: int = None):
def reset_state(self, batch_size: int = None, **kwargs):
"""Reset the delay data.
"""
self.target.value = variable_(self.target_init, self.target.size_without_batch, batch_size)
Expand Down Expand Up @@ -476,9 +476,9 @@ def reset_state(self, *args, **kwargs):
pass


def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay:
def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_data=None) -> Delay:
if isinstance(info, bm.Variable):
return VarDelay(info)
return VarDelay(info, init=initial_delay_data)

elif isinstance(info, ReturnInfo):
# batch size
Expand Down Expand Up @@ -510,6 +510,6 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay:

# variable
target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names)
return DataDelay(target, data_init=info.data)
return DataDelay(target, data_init=info.data, init=initial_delay_data)
else:
raise TypeError
38 changes: 37 additions & 1 deletion brainpy/_src/dyn/projections/aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def update(self, x):
else:
return x >> self.syn >> self.delay

def reset_state(self, *args, **kwargs):
pass


class _AlignPost(DynamicalSystem):
def __init__(self,
Expand All @@ -39,6 +42,9 @@ def __init__(self,
def update(self, *args, **kwargs):
self.out.bind_cond(self.syn(*args, **kwargs))

def reset_state(self, *args, **kwargs):
pass


class _AlignPreMg(DynamicalSystem):
def __init__(self, access, syn):
Expand All @@ -49,6 +55,9 @@ def __init__(self, access, syn):
def update(self, *args, **kwargs):
return self.syn(self.access())

def reset_state(self, *args, **kwargs):
pass


def _get_return(return_info):
if isinstance(return_info, bm.Variable):
Expand Down Expand Up @@ -132,6 +141,9 @@ def update(self, x):
self.refs['out'].bind_cond(current)
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPostMg1(Projection):
r"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
Expand Down Expand Up @@ -224,6 +236,9 @@ def update(self, x):
self.refs['syn'].add_current(current) # synapse post current
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPostMg2(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
Expand Down Expand Up @@ -352,6 +367,9 @@ def update(self):
self.refs['syn'].add_current(current) # synapse post current
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPost1(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
Expand Down Expand Up @@ -438,6 +456,9 @@ def update(self, x):
self.refs['syn'].add_current(current)
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPost2(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of postsynaptic neuron group.
Expand Down Expand Up @@ -561,6 +582,9 @@ def update(self):
self.refs['out'].bind_cond(g) # synapse post current
return g

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPreMg1(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
Expand Down Expand Up @@ -686,6 +710,9 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPreMg2(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
Expand Down Expand Up @@ -814,6 +841,9 @@ def update(self):
self.refs['out'].bind_cond(current)
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPre1(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
Expand Down Expand Up @@ -933,6 +963,9 @@ def update(self, x=None):
self.refs['out'].bind_cond(current)
return current

def reset_state(self, *args, **kwargs):
pass


class ProjAlignPre2(Projection):
"""Synaptic projection which defines the synaptic computation with the dimension of presynaptic neuron group.
Expand Down Expand Up @@ -1052,4 +1085,7 @@ def update(self):
spk = self.refs['delay'].at(self.name)
g = self.comm(self.syn(spk))
self.refs['out'].bind_cond(g)
return g
return g

def reset_state(self, *args, **kwargs):
pass
3 changes: 3 additions & 0 deletions brainpy/_src/dyn/projections/others.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ def __init__(
self.freq = check.is_float(freq, min_bound=0., allow_int=True)
self.weight = check.is_float(weight, allow_int=True)

def reset_state(self, *args, **kwargs):
pass

def update(self):
p = self.freq * share['dt'] / 1e3
a = self.num_input * p
Expand Down
3 changes: 3 additions & 0 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ def __init__(
self.A1 = parameter(A1, sizes=self.pre_num)
self.A2 = parameter(A2, sizes=self.post_num)

def reset_state(self, *args, **kwargs):
pass

def _init_trace(
self,
target: DynamicalSystem,
Expand Down
12 changes: 6 additions & 6 deletions brainpy/_src/dyn/synapses/delay_couplings.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def __init__(
def update(self):
# delays
axis = self.coupling_var1.ndim
delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0]
delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}')
if self.delay_steps is None:
diffusive = (jnp.expand_dims(self.coupling_var1.value, axis=axis) -
jnp.expand_dims(self.coupling_var2.value, axis=axis - 1))
Expand All @@ -201,13 +201,13 @@ def update(self):
indices = (slice(None, None, None), jnp.arange(self.coupling_var1.size),)
else:
indices = (jnp.arange(self.coupling_var1.size),)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (..., pre.num)
f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (..., pre.num)
delays = f(self.delay_steps) # (..., post.num, pre.num)
diffusive = (jnp.moveaxis(bm.as_jax(delays), axis - 1, axis) -
jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num)
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
elif self.delay_type == 'int':
delayed_data = delay_var(self.delay_steps) # (..., pre.num)
delayed_data = delay_var.retrieve(self.delay_steps) # (..., pre.num)
diffusive = (jnp.expand_dims(delayed_data, axis=axis) -
jnp.expand_dims(self.coupling_var2.value, axis=axis - 1)) # (..., pre.num, post.num)
diffusive = (self.conn_mat * diffusive).sum(axis=axis - 1)
Expand Down Expand Up @@ -276,19 +276,19 @@ def __init__(
def update(self):
# delay function
axis = self.coupling_var.ndim
delay_var: bm.LengthDelay = self.get_delay_var(f'delay_{id(self.delay_var)}')[0]
delay_var = self.get_delay_var(f'delay_{id(self.delay_var)}')
if self.delay_steps is None:
additive = self.coupling_var @ self.conn_mat
elif self.delay_type == 'array':
if isinstance(self.mode, bm.TrainingMode):
indices = (slice(None, None, None), jnp.arange(self.coupling_var.size),)
else:
indices = (jnp.arange(self.coupling_var.size),)
f = vmap(lambda steps: delay_var(steps, *indices), in_axes=1) # (.., pre.num,)
f = vmap(lambda steps: delay_var.retrieve(steps, *indices), in_axes=1) # (.., pre.num,)
delays = f(self.delay_steps) # (..., post.num, pre.num)
additive = (self.conn_mat * jnp.moveaxis(delays, axis - 1, axis)).sum(axis=axis - 1)
elif self.delay_type == 'int':
delayed_var = delay_var(self.delay_steps) # (..., pre.num)
delayed_var = delay_var.retrieve(self.delay_steps) # (..., pre.num)
additive = delayed_var @ self.conn_mat
else:
raise ValueError
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dynold/synapses/abstract_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def reset_state(self, batch_size=None):
def update(self, pre_spike=None):
# pre-synaptic spikes
if pre_spike is None:
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", delay_step=self.delay_step)
pre_spike = self.get_delay_data(f"{self.pre.name}.spike", self.delay_step)
pre_spike = bm.as_jax(pre_spike)
if self.stop_spike_gradient:
pre_spike = jax.lax.stop_gradient(pre_spike)
Expand Down
Loading

0 comments on commit e849a9a

Please sign in to comment.