Update operator selection strategy for csr matmat
homogeneous weights -> taichi,
heterogeneous weights (CPU) -> taichi, heterogeneous weights (GPU) -> cuSPARSE
Routhleck committed Feb 17, 2024
1 parent c37ce92 commit 9ea9800
Showing 1 changed file with 22 additions and 24 deletions.
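
The selection strategy above amounts to a dispatch on the weight data and the execution platform: a length-1 data array holds a single homogeneous weight, anything longer holds heterogeneous weights. Below is a minimal sketch of that rule, not part of the commit; the helper name select_csr_matmat_backend is hypothetical, while brainpy.math.get_platform() is the platform query used in the diff further down.

import brainpy.math as bm

def select_csr_matmat_backend(data):
    # Homogeneous (scalar) weight: always use the taichi kernels.
    if data.shape[0] == 1:
        return 'taichi'
    # Heterogeneous weights: cuSPARSE on GPU, taichi kernels on CPU.
    if bm.get_platform() == 'gpu':
        return 'cusparse'
    return 'taichi'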
46 changes: 22 additions & 24 deletions brainpy/_src/math/sparse/_csr_mm.py
@@ -5,6 +5,7 @@

import jax
import numpy as np
import brainpy.math as bm
from jax import numpy as jnp
from jax.interpreters import ad
from jax.experimental.sparse import csr
@@ -47,19 +48,7 @@ def csrmm(
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product.
"""
data = jnp.atleast_1d(data)
out_shape = shape[1] if transpose else shape[0]
result_shape = (out_shape, matrix.shape[1])
# if the shape of indices is (0,), then we return a zero matrix

if data.shape[0] != 1:
if indices.shape[0] == 0:
return jnp.zeros(result_shape, dtype=data.dtype)
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
if indices.shape[0] == 0:
return jnp.zeros(result_shape, dtype=data.dtype)
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]


def raw_csrmm_taichi(
@@ -76,14 +65,14 @@ def raw_csrmm_taichi(
indices = as_jax(indices)
indptr = as_jax(indptr)
matrix = as_jax(matrix)
data = jnp.atleast_1d(data)

if matrix.dtype == jnp.bool_:
matrix = as_jax(matrix, dtype=data.dtype)

if data.dtype != matrix.dtype:
raise TypeError('The types of data and matrix should be the same. '
f'But we got {data.dtype} != {matrix.dtype}.')
assert data.ndim == indices.ndim == indptr.ndim == 1
assert matrix.ndim == 2

if np.ndim(data) == 1:
@@ -100,10 +89,19 @@
result_shape = (out_shape, matrix.shape[1])

assert matrix.shape[0] == (shape[0] if transpose else shape[1])

if indices.shape[0] == 0:
return [jnp.zeros(result_shape, dtype=data.dtype), ]
# homo -> taichi,
# heter(CPU) -> taichi, heter(GPU) -> cusparse
if data.shape[0] != 1:
return _csr_matmat_heter_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose)
if bm.get_platform() == 'gpu':
return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
else:
if transpose:
prim = _csr_matmat_transpose_heter_p
else:
prim = _csr_matmat_heter_p
else:
if transpose:
prim = _csr_matmat_transpose_homo_p
@@ -290,13 +288,13 @@ def _define_op(cpu_kernel, gpu_kernel):
return prim


# # transpose heter
# _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu,
# gpu_kernel=_csr_matmat_transpose_heter_gpu)
#
# # no transpose heter
# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu,
# gpu_kernel=_csr_matmat_heter_gpu)
# transpose heter
_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu,
gpu_kernel=_csr_matmat_transpose_heter_gpu)

# no transpose heter
_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu,
gpu_kernel=_csr_matmat_heter_gpu)

# transpose homo
_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu,
@@ -307,5 +305,5 @@ def _define_op(cpu_kernel, gpu_kernel):
gpu_kernel=_csr_matmat_homo_gpu)

# heter CUSPARSE
_csr_matmat_heter_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_heter_p)
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)

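For reference, a minimal usage sketch of the updated routine, assuming it is exposed as brainpy.math.sparse.csrmm; the signature and output shape follow the docstring shown in the diff, and the example values are illustrative only.

import jax.numpy as jnp
import brainpy.math as bm

# A 2x3 CSR matrix with heterogeneous weights:
# [[1., 0., 2.],
#  [0., 3., 0.]]
data = jnp.array([1., 2., 3.])
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
dense = jnp.ones((3, 4))  # dense right-hand operand

# Heterogeneous data: routes to the taichi kernels on CPU and to the
# cuSPARSE primitive on GPU, per the new selection strategy.
out = bm.sparse.csrmm(data, indices, indptr, dense, shape=(2, 3), transpose=False)
print(out.shape)  # (2, 4)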