Skip to content

Commit

Permalink
Update csr_matmat.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 9, 2024
1 parent bbb57e4 commit 2c5c24c
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,33 +234,33 @@ def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _event_csr_matmat_sum(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(out.shape[0]):
def _event_csr_matmat_dw(dy: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
dw: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(dw.shape[0]):
temp_sum = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
for col_k in range(matrix.shape[1]):
if matrix[col_indices[row_j], col_k] != 0.:
temp_sum += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i] = temp_sum
temp_sum += dy[row_j] * matrix[col_indices[row_j], col_k]
dw[row_i] = temp_sum


@ti.kernel
def _event_csr_matmat_bool_sum(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(out.shape[0]):
def _event_csr_matmat_bool_dw(dy: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
dw: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(dw.shape[0]):
temp_sum = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
for col_k in range(matrix.shape[1]):
if matrix[col_indices[row_j], col_k]:
temp_sum += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i] = temp_sum
temp_sum += dy[row_j] * matrix[col_indices[row_j], col_k]
dw[row_i] = temp_sum


def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
Expand Down Expand Up @@ -291,13 +291,11 @@ def _event_csr_matmat_transpose(
else: # heter
matrix = jnp.asarray(matrix)
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_bool_sum
prim = _event_csr_matmat_bool_dw
else:
prim = _event_csr_matmat_sum
prim = _event_csr_matmat_dw
ct_data = prim(ct[0], indices, indptr, matrix,
out=jax.ShapeDtypeStruct((data.aval.shape[0], matrix.shape[1]), data.aval.dtype))[0]
# row, col = csr_to_coo(indices, indptr)
# ct_data = (ct[0][row] * matrix[col]).sum(1)
return ct_data, indices, indptr, matrix


Expand Down Expand Up @@ -340,11 +338,11 @@ def _define_op(cpu_kernel, gpu_kernel):
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
gpu_kernel=_event_csr_matmat_bool_homo)

_event_csr_matmat_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_sum,
gpu_kernel=_event_csr_matmat_sum)
_event_csr_matmat_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_dw,
gpu_kernel=_event_csr_matmat_dw)

_event_csr_matmat_bool_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_sum,
gpu_kernel=_event_csr_matmat_bool_sum)
_event_csr_matmat_bool_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_dw,
gpu_kernel=_event_csr_matmat_bool_dw)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
Expand Down

0 comments on commit 2c5c24c

Please sign in to comment.