From 70b0cb97afc4239b5f34db3a40e8e5df518283c7 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Mon, 3 Jun 2024 11:10:13 +0800
Subject: [PATCH] [math] Optimize event csr matmat operators

Select the event-driven `*_homo_p` primitives for non-boolean matrices
instead of returning early through `normal_csrmm`, and compute the
cotangent of the heterogeneous CSR values with a dedicated Taichi kernel
in place of the `csr_to_coo` gather/multiply/sum expression.

---
Review notes (this section is ignored by `git am`):

  * The transpose rule now calls the registered `*_sum_p` ops with
    `outs=[...]`, the same calling convention used for `prim(...)` in
    `raw_event_csrmm_taichi`, rather than invoking the raw Taichi kernels
    with `out=`.
  * The cotangent kernels take the 2-D cotangent and emit one value per
    stored element, i.e. exactly what `(ct[0][row] * matrix[col]).sum(1)`
    produced. The earlier draft declared the cotangent as 1-D and wrote
    one value per row into a buffer requested as 2-D.
  * The output spec is sized from `indices` and typed from `ct[0]`; both
    are already used as concrete arrays by the expression being replaced.
  * Commented-out code and the whitespace-only re-indent of the
    `prim(data, ...)` arguments were dropped from the patch.
  * The `index` blob hashes below are those of the original patch and no
    longer describe the post-image. Context-line indentation was
    re-derived; verify with `git apply --check` before merging.

 brainpy/_src/math/event/csr_matmat.py         | 60 ++++++++++++++++---
 .../_src/math/event/tests/test_event_csrmm.py |  3 +-
 2 files changed, 56 insertions(+), 7 deletions(-)

diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py
index 33677691..b8b3a806 100644
--- a/brainpy/_src/math/event/csr_matmat.py
+++ b/brainpy/_src/math/event/csr_matmat.py
@@ -101,10 +101,10 @@ def raw_event_csrmm_taichi(
       if matrix.dtype == jnp.bool_:
-        prim = _event_csr_matmat_transpose_homo_p
+        prim = _event_csr_matmat_transpose_bool_homo_p
       else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+        prim = _event_csr_matmat_transpose_homo_p
     else:
       if matrix.dtype == jnp.bool_:
         prim = _event_csr_matmat_bool_homo_p
       else:
-        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+        prim = _event_csr_matmat_homo_p
   return prim(data,
@@ -234,6 +234,42 @@ def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
       out[row_i, col_k] = r * value
 
 
+# Cotangent of the heterogeneous CSR values for ``csr @ matrix``:
+#   out[j] = sum_k ct[row(j), k] * matrix[col_indices[j], k]
+# i.e. the same result as ``(ct[row] * matrix[col]).sum(1)`` after
+# ``csr_to_coo``, without materialising the (nnz, k) intermediates.
+@ti.kernel
+def _event_csr_matmat_sum(ct: ti.types.ndarray(ndim=2),
+                          col_indices: ti.types.ndarray(ndim=1),
+                          row_ptr: ti.types.ndarray(ndim=1),
+                          matrix: ti.types.ndarray(ndim=2),
+                          out: ti.types.ndarray(ndim=1)):
+  for row_i in range(row_ptr.shape[0] - 1):
+    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      acc = 0.
+      for col_k in range(matrix.shape[1]):
+        # event-driven: silent (zero) entries contribute nothing
+        if matrix[col_indices[j], col_k] != 0.:
+          acc += ct[row_i, col_k] * matrix[col_indices[j], col_k]
+      out[j] = acc
+
+
+# Boolean-event variant of ``_event_csr_matmat_sum``: a True entry counts as 1.
+@ti.kernel
+def _event_csr_matmat_bool_sum(ct: ti.types.ndarray(ndim=2),
+                               col_indices: ti.types.ndarray(ndim=1),
+                               row_ptr: ti.types.ndarray(ndim=1),
+                               matrix: ti.types.ndarray(ndim=2),
+                               out: ti.types.ndarray(ndim=1)):
+  for row_i in range(row_ptr.shape[0] - 1):
+    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      acc = 0.
+      for col_k in range(matrix.shape[1]):
+        if matrix[col_indices[j], col_k]:
+          acc += ct[row_i, col_k]
+      out[j] = acc
+
+
 def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
   return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)
 
@@ -261,8 +297,14 @@ def _event_csr_matmat_transpose(
         ct_data = jnp.sum(ct[0] * ct_data)
       else:  # heter
         matrix = jnp.asarray(matrix)
-        row, col = csr_to_coo(indices, indptr)
-        ct_data = (ct[0][row] * matrix[col]).sum(1)
+        # Dispatch to the registered XLA ops (not the raw Taichi kernels) and
+        # request one cotangent per stored value, matching ``(...).sum(1)``.
+        if matrix.dtype == jnp.bool_:
+          prim = _event_csr_matmat_bool_sum_p
+        else:
+          prim = _event_csr_matmat_sum_p
+        ct_data = prim(ct[0], indices, indptr, matrix,
+                       outs=[jax.ShapeDtypeStruct((indices.shape[0],), ct[0].dtype)])[0]
     return ct_data, indices, indptr, matrix
 
 
@@ -305,6 +347,12 @@ def _define_op(cpu_kernel, gpu_kernel):
 _event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
                                            gpu_kernel=_event_csr_matmat_bool_homo)
 
+_event_csr_matmat_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_sum,
+                                      gpu_kernel=_event_csr_matmat_sum)
+
+_event_csr_matmat_bool_sum_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_sum,
+                                           gpu_kernel=_event_csr_matmat_bool_sum)
+
 # heter CUSPARSE
 _csr_matmat_cusparse_p = csr.csr_matmat_p
-register_general_batching(_csr_matmat_cusparse_p)
\ No newline at end of file
+register_general_batching(_csr_matmat_cusparse_p)
diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py
index 12a35ef3..17d02632 100644
--- a/brainpy/_src/math/event/tests/test_event_csrmm.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmm.py
@@ -141,5 +141,6 @@ def test_homo_grad(self, transpose, shape, homo_data):
     r1 = dense_f1(homo_data)
     r2 = jax.grad(sum_op(bm.event.csrmm))(
-      bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      bm.asarray([homo_data]), indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
       transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))