From 94e90413e2d732a43b6cd70c6d4d1cee9c8d69ef Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 3 Mar 2024 12:43:24 +0800 Subject: [PATCH] Update delayvars.py --- brainpy/_src/math/delayvars.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/delayvars.py b/brainpy/_src/math/delayvars.py index eb8e27c8f..676e4286b 100644 --- a/brainpy/_src/math/delayvars.py +++ b/brainpy/_src/math/delayvars.py @@ -11,7 +11,7 @@ from brainpy import check from brainpy.check import is_float, is_integer, jit_error from brainpy.errors import UnsupportedError -from .compat_numpy import vstack, broadcast_to +from .compat_numpy import broadcast_to, expand_dims, concatenate from .environment import get_dt, get_float from .interoperability import as_jax from .ndarray import ndarray, Array @@ -392,6 +392,7 @@ def reset( dtype=delay_target.dtype), batch_axis=batch_axis) else: + self.data.value self.data._value = jnp.zeros((self.num_delay_step,) + delay_target.shape, dtype=delay_target.dtype) @@ -472,7 +473,7 @@ def update(self, value: Union[numbers.Number, Array, jax.Array] = None): elif self.update_method == CONCAT_UPDATE: if self.num_delay_step >= 2: - self.data.value = vstack([broadcast_to(value, self.data.shape[1:]), self.data[1:]]) + self.data.value = concatenate([expand_dims(value, 0), self.data[:-1]], axis=0) else: self.data[:] = value