diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index dfaeba9a..dbfc2c2f 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -479,11 +479,9 @@ def _event_csr_matvec_dW(dy: ti.types.ndarray(), events: ti.types.ndarray(), dw: ti.types.ndarray()): for row_i in range(indptr.shape[0] - 1): - r = 0. for j in range(indptr[row_i], indptr[row_i + 1]): - if events[j] != 0.: - r += dy[indices[j]] - dw[row_i] = r + if events[indices[j]] != 0.: + dw[j] = dy[row_i] @ti.kernel @@ -493,11 +491,9 @@ def _event_csr_matvec_dW_bool(dy: ti.types.ndarray(), events: ti.types.ndarray(), dw: ti.types.ndarray()): for row_i in range(indptr.shape[0] - 1): - r = 0. for j in range(indptr[row_i], indptr[row_i + 1]): - if events[j]: - r += dy[indices[j]] - dw[row_i] = r + if events[indices[j]]: + dw[j] = dy[row_i] def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): @@ -594,4 +590,4 @@ def _define_op(cpu_kernel, gpu_kernel): # compute dW bool _event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool, - gpu_kernel=_event_csr_matvec_dW_bool) \ No newline at end of file + gpu_kernel=_event_csr_matvec_dW_bool)