Skip to content

Commit

Permalink
Update delayvars.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 3, 2024
1 parent a478cce commit 94e9041
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions brainpy/_src/math/delayvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from brainpy import check
from brainpy.check import is_float, is_integer, jit_error
from brainpy.errors import UnsupportedError
from .compat_numpy import vstack, broadcast_to
from .compat_numpy import broadcast_to, expand_dims, concatenate
from .environment import get_dt, get_float
from .interoperability import as_jax
from .ndarray import ndarray, Array
Expand Down Expand Up @@ -392,6 +392,7 @@ def reset(
dtype=delay_target.dtype),
batch_axis=batch_axis)
else:
self.data.value
self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape,
dtype=delay_target.dtype)

Expand Down Expand Up @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None):

elif self.update_method == CONCAT_UPDATE:
if self.num_delay_step >= 2:
self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]])
self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0)
else:
self.data[:] = value

Expand Down

0 comments on commit 94e9041

Please sign in to comment.