Skip to content

Commit

Permalink
Optimize automatic differentiation of csrmv
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 6, 2024
1 parent 70b0cb9 commit b3e6fe0
Showing 1 changed file with 37 additions and 2 deletions.
39 changes: 37 additions & 2 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,27 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
j += 32
out[row_i] += r # TODO: warp-level primitive

@ti.kernel
def _event_csr_matvec_dW(values: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i] != 0.:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]

@ti.kernel
def _event_csr_matvec_dW_bool(values: ti.types.ndarray(),
indices: ti.types.ndarray(),
indptr: ti.types.ndarray(),
events: ti.types.ndarray(),
out: ti.types.ndarray()):
for i in range(events.shape[0]):
if events[i]:
for j in range(indptr[i], indptr[i + 1]):
out[j] = values[indices[j]]

def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)
Expand All @@ -466,8 +487,14 @@ def _event_csr_matvec_transpose_taichi(
ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
ct_values = jnp.inner(ct[0], ct_values)
else: # heterogeneous values
row, col = csr_to_coo(indices, indptr)
ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
# row, col = csr_to_coo(indices, indptr)
# ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
if events.dtype == jnp.bool_:
ct_values = _event_csr_matvec_dW_bool(ct[0], indices, indptr, events,
outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0]
else:
ct_values = _event_csr_matvec_dW(ct[0], indices, indptr, events,
outs=[jax.ShapeDtypeStruct((values.shape[0],), values.dtype)])[0]
return ct_values, indices, indptr, events


Expand Down Expand Up @@ -509,3 +536,11 @@ def _define_op(cpu_kernel, gpu_kernel):
# not transpose heter
_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu,
_event_csr_matvec_heter_gpu)

# compute dW
_event_csrmv_dW_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW,
gpu_kernel=_event_csr_matvec_dW)

# compute dW bool
_event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool,
gpu_kernel=_event_csr_matvec_dW_bool)

0 comments on commit b3e6fe0

Please sign in to comment.