diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index c09c0ff7..b5e21446 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -5,6 +5,7 @@ import jax import numpy as np +import brainpy.math as bm from jax import numpy as jnp from jax.interpreters import ad from jax.experimental.sparse import csr @@ -47,19 +48,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - data = jnp.atleast_1d(data) - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - - if data.shape[0] != 1: - if indices.shape[0] == 0: - return jnp.zeros(result_shape, dtype=data.dtype) - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose) - else: - if indices.shape[0] == 0: - return jnp.zeros(result_shape, dtype=data.dtype) - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] def raw_csrmm_taichi( @@ -76,6 +65,7 @@ def raw_csrmm_taichi( indices = as_jax(indices) indptr = as_jax(indptr) matrix = as_jax(matrix) + data = jnp.atleast_1d(data) if matrix.dtype == jnp.bool_: matrix = as_jax(matrix, dtype=data.dtype) @@ -83,7 +73,6 @@ def raw_csrmm_taichi( if data.dtype != matrix.dtype: raise TypeError('The types of data and vector should be the same. ' f'But we got {data.dtype} != {matrix.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 if np.ndim(data) == 1: @@ -100,10 +89,19 @@ def raw_csrmm_taichi( result_shape = (out_shape, matrix.shape[1]) assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if indices.shape[0] == 0: return [jnp.zeros(result_shape, dtype=data.dtype), ] + # homo -> taichi, + # heter(CPU) -> taichi, heter(GPU) -> cusparse if data.shape[0] != 1: - return _csr_matmat_heter_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose) + if bm.get_platform() == 'gpu': + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] + else: + if transpose: + prim = _csr_matmat_transpose_heter_p + else: + prim = _csr_matmat_heter_p else: if transpose: prim = _csr_matmat_transpose_homo_p @@ -290,13 +288,13 @@ def _define_op(cpu_kernel, gpu_kernel): return prim -# # transpose heter -# _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, -# gpu_kernel=_csr_matmat_transpose_heter_gpu) -# -# # no transpose heter -# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, -# gpu_kernel=_csr_matmat_heter_gpu) +# transpose heter +_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, + gpu_kernel=_csr_matmat_transpose_heter_gpu) + +# no transpose heter +_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, + gpu_kernel=_csr_matmat_heter_gpu) # transpose homo _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, @@ -307,5 +305,5 @@ def _define_op(cpu_kernel, gpu_kernel): gpu_kernel=_csr_matmat_homo_gpu) # heter CUSPARSE -_csr_matmat_heter_p = csr.csr_matmat_p -register_general_batching(_csr_matmat_heter_p) +_csr_matmat_cusparse_p = csr.csr_matmat_p +register_general_batching(_csr_matmat_cusparse_p)