Skip to content

Commit

Permalink
Update csr_matvec.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 9, 2024
1 parent 86048ba commit 6cc08f5
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,11 +479,9 @@ def _event_csr_matvec_dW(dy: ti.types.ndarray(),
events: ti.types.ndarray(),
dw: ti.types.ndarray()):
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[j] != 0.:
r += dy[indices[j]]
dw[row_i] = r
if events[indices[j]] != 0.:
dw[j] = dy[row_i]


@ti.kernel
Expand All @@ -493,11 +491,9 @@ def _event_csr_matvec_dW_bool(dy: ti.types.ndarray(),
events: ti.types.ndarray(),
dw: ti.types.ndarray()):
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[j]:
r += dy[indices[j]]
dw[row_i] = r
if events[indices[j]]:
dw[j] = dy[row_i]


def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
Expand Down Expand Up @@ -594,4 +590,4 @@ def _define_op(cpu_kernel, gpu_kernel):

# compute dW bool
_event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool,
gpu_kernel=_event_csr_matvec_dW_bool)
gpu_kernel=_event_csr_matvec_dW_bool)

0 comments on commit 6cc08f5

Please sign in to comment.