[math] Optimize event csr matmat operators
Routhleck committed Jun 3, 2024
1 parent c259c1b commit 70b0cb9
Showing 2 changed files with 56 additions and 13 deletions.
66 changes: 54 additions & 12 deletions brainpy/_src/math/event/csr_matmat.py
@@ -99,21 +99,21 @@ def raw_event_csrmm_taichi(
  else:
    if transpose:
      if matrix.dtype == jnp.bool_:
-        prim = _event_csr_matmat_transpose_homo_p
+        prim = _event_csr_matmat_transpose_bool_homo_p
      else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+        prim = _event_csr_matmat_transpose_homo_p
    else:
      if matrix.dtype == jnp.bool_:
        prim = _event_csr_matmat_bool_homo_p
      else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+        prim = _event_csr_matmat_homo_p
    return prim(data,
                indices,
                indptr,
                matrix,
                outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)],
                transpose=transpose,
                shape=shape)
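As an orientation aid, here is an editor's usage sketch of the homogeneous-weight branch that this hunk now routes to dedicated primitives (the call form follows the test file below; the concrete shapes and values are illustrative assumptions, not part of the commit):

import brainpy.math as bm

# Hypothetical example: a 2x3 CSR connectivity carrying a single shared weight,
# multiplied by a boolean 3x4 event matrix.  A one-element `data` array selects
# the homogeneous path above; a boolean `matrix` picks the *_bool_homo
# primitive, while a floating-point matrix picks the *_homo one.
indptr = bm.asarray([0, 2, 3])
indices = bm.asarray([0, 2, 1])
events = bm.random.rand(3, 4) < 0.5
out = bm.event.csrmm(bm.asarray([1.5]), indices, indptr, events,
                     shape=(2, 3), transpose=False)  # expected result shape: (2, 4)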


# taichi kernels
@@ -234,6 +234,36 @@ def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
    out[row_i, col_k] = r * value


+@ti.kernel
+def _event_csr_matmat_sum(values: ti.types.ndarray(ndim=1),
+                          col_indices: ti.types.ndarray(ndim=1),
+                          row_ptr: ti.types.ndarray(ndim=1),
+                          matrix: ti.types.ndarray(ndim=2),
+                          out: ti.types.ndarray(ndim=1)):
+  for row_i in ti.ndrange(out.shape[0]):
+    temp_sum = 0.
+    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      for col_k in range(matrix.shape[1]):
+        if matrix[col_indices[row_j], col_k] != 0.:
+          temp_sum += values[row_j] * matrix[col_indices[row_j], col_k]
+    out[row_i] = temp_sum
+
+
+@ti.kernel
+def _event_csr_matmat_bool_sum(values: ti.types.ndarray(ndim=1),
+                               col_indices: ti.types.ndarray(ndim=1),
+                               row_ptr: ti.types.ndarray(ndim=1),
+                               matrix: ti.types.ndarray(ndim=2),
+                               out: ti.types.ndarray(ndim=1)):
+  for row_i in ti.ndrange(out.shape[0]):
+    temp_sum = 0.
+    for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      for col_k in range(matrix.shape[1]):
+        if matrix[col_indices[row_j], col_k]:
+          temp_sum += values[row_j] * matrix[col_indices[row_j], col_k]
+    out[row_i] = temp_sum
+
+
def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
  return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)

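The two kernels added above implement a per-row, event-masked reduction. As a reading aid, here is an editor's plain-NumPy transcription of `_event_csr_matmat_sum` (not part of the commit); `_event_csr_matmat_bool_sum` differs only in testing the boolean entries directly instead of comparing against zero:

import numpy as np

def event_csr_matmat_sum_reference(values, col_indices, row_ptr, matrix):
  # For each CSR row i, accumulate values[p] * matrix[col_indices[p], k] over
  # the row's stored entries p and every column k whose event entry is non-zero.
  out = np.zeros(row_ptr.shape[0] - 1, dtype=values.dtype)
  for i in range(out.shape[0]):
    acc = 0.
    for p in range(row_ptr[i], row_ptr[i + 1]):
      row = matrix[col_indices[p]]
      acc += values[p] * row[row != 0].sum()
    out[i] = acc
  return out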
Expand Down Expand Up @@ -261,8 +291,14 @@ def _event_csr_matmat_transpose(
ct_data = jnp.sum(ct[0] * ct_data)
else: # heter
matrix = jnp.asarray(matrix)
row, col = csr_to_coo(indices, indptr)
ct_data = (ct[0][row] * matrix[col]).sum(1)
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_bool_sum
else:
prim = _event_csr_matmat_sum
ct_data = prim(ct[0], indices, indptr, matrix,
out=jax.ShapeDtypeStruct((data.shape[0], matrix.shape[1]), data.dtype))[0]
# row, col = csr_to_coo(indices, indptr)
# ct_data = (ct[0][row] * matrix[col]).sum(1)
return ct_data, indices, indptr, matrix
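For reference, the COO expression retained in the comments computes, for each stored entry p of the sparse matrix at position (row[p], col[p]):

ct_data[p] = sum over k of ct[0][row[p], k] * matrix[col[p], k]

which is what `(ct[0][row] * matrix[col]).sum(1)` expands to; the kernel-based path above is intended to form the same event-masked products without materializing the gathered `ct[0][row]` and `matrix[col]` arrays.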


@@ -305,6 +341,12 @@ def _define_op(cpu_kernel, gpu_kernel):
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
                                           gpu_kernel=_event_csr_matmat_bool_homo)

+_event_csr_matmat_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_sum,
+                                      gpu_kernel=_event_csr_matmat_sum)
+
+_event_csr_matmat_bool_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_sum,
+                                           gpu_kernel=_event_csr_matmat_bool_sum)
+
# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)
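The new kernels are wrapped as `XLACustomOp` primitives above. For readers unfamiliar with that wrapper, an editor's sketch of the usual invocation, following the `outs=[jax.ShapeDtypeStruct(...)]` pattern used by `raw_event_csrmm_taichi` earlier in this file (the one-entry-per-row output length is an illustrative assumption, not taken from the commit):

import jax

def _call_sum_primitive(values, indices, indptr, matrix):
  # Invoke the wrapped Taichi kernel with an explicit output description;
  # the primitive returns a list with one array per requested output.
  n_rows = indptr.shape[0] - 1
  res = _event_csr_matmat_sum_p(values, indices, indptr, matrix,
                                outs=[jax.ShapeDtypeStruct((n_rows,), values.dtype)])
  return res[0]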
3 changes: 2 additions & 1 deletion brainpy/_src/math/event/tests/test_event_csrmm.py
@@ -140,7 +140,8 @@ def test_homo_grad(self, transpose, shape, homo_data):
                          argnums=0)
    r1 = dense_f1(homo_data)
    r2 = jax.grad(sum_op(bm.event.csrmm))(
-      bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      bm.asarray([homo_data]), indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
      transpose=transpose)

    self.assertTrue(bm.allclose(r1, r2))
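`sum_op` is a small helper defined elsewhere in the test module and not shown in this diff; a plausible reconstruction (editor's assumption) is:

def sum_op(op):
  # Reduce the operator's output to a scalar so jax.grad can differentiate
  # it with respect to the first argument.
  def func(*args, **kwargs):
    r = op(*args, **kwargs)
    return r.sum()
  return func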