diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index a8f55afb..33677691 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -7,13 +7,15 @@ import numpy as np from jax import numpy as jnp from jax.interpreters import ad +from jax.experimental.sparse import csr from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp) +from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm from brainpy._src.math.sparse.utils import csr_to_coo +from brainpy._src.math.defaults import float_ ti = import_taichi() @@ -86,23 +88,26 @@ def raw_event_csrmm_taichi( return [jnp.zeros(result_shape, dtype=data.dtype), ] assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - if transpose: + + # homo -> taichi + # heter -> cusparse + if data.shape[0] != 1: if matrix.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csr_matmat_transpose_bool_homo_p + # change dtype to float + matrix = matrix.astype(float_) + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] + else: + if transpose: + if matrix.dtype == jnp.bool_: + prim = _event_csr_matmat_transpose_homo_p else: - prim = _event_csr_matmat_transpose_bool_heter_p + return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - else: - if matrix.dtype == jnp.bool_: - if data.shape[0] == 1: + if matrix.dtype == jnp.bool_: prim = _event_csr_matmat_bool_homo_p else: - prim = _event_csr_matmat_bool_heter_p - else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - return prim(data, + return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) + return prim(data, indices, indptr, matrix, @@ -299,3 +304,7 @@ def _define_op(cpu_kernel, gpu_kernel): # bool no transpose homo _event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, gpu_kernel=_event_csr_matmat_bool_homo) + +# heter CUSPARSE +_csr_matmat_cusparse_p = csr.csr_matmat_p +register_general_batching(_csr_matmat_cusparse_p) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index d0ea66ca..47c24fa4 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -147,7 +147,6 @@ def raw_csrmm_taichi( # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, # gpu_kernel=_csr_matmat_heter) - @ti.kernel def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), @@ -203,10 +202,9 @@ def _csr_matmat_transpose( if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr): raise ValueError("Cannot transpose with respect to sparse indices.") assert ad.is_undefined_primal(matrix) - ct_matrix = _csr_matmat_transpose_homo_p(col_indices, row_ptr, ct[0], - shape=shape, - transpose=not transpose, - outs=[jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)]) + ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0], + shape=shape, + transpose=not transpose) return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0])