Skip to content

Commit

Permalink
Update csr_matvec.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 8, 2024
1 parent 4596c5e commit 86048ba
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,51 +449,55 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _event_csr_matvec_dW_transpose(values: ti.types.ndarray(),
def _event_csr_matvec_dW_transpose(dy: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
dw: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i] != 0.:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]
dw[j] = dy[indices[j]]


@ti.kernel
def _event_csr_matvec_dW_bool_transpose(values: ti.types.ndarray(),
def _event_csr_matvec_dW_bool_transpose(dy: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
dw: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i]:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]
dw[j] = dy[indices[j]]


@ti.kernel
def _event_csr_matvec_dW(values: ti.types.ndarray(),
def _event_csr_matvec_dW(dy: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i] != 0.:
for j in range(indptr[i], indptr[i + 1]):
out[indices[j]] += values[j] * events[i]
dw: ti.types.ndarray()):
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[j] != 0.:
r += dy[indices[j]]
dw[row_i] = r


@ti.kernel
def _event_csr_matvec_dW_bool(values: ti.types.ndarray(),
def _event_csr_matvec_dW_bool(dy: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i]:
for j in range(indptr[i], indptr[i + 1]):
out[indices[j]] += values[j]
dw: ti.types.ndarray()):
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[j]:
r += dy[indices[j]]
dw[row_i] = r


def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
Expand Down

0 comments on commit 86048ba

Please sign in to comment.