accelerate csrmm homo
chaoming0625 committed Mar 8, 2024
1 parent 116df57 commit c298f48
Showing 1 changed file with 122 additions and 136 deletions.
258 changes: 122 additions & 136 deletions brainpy/_src/math/sparse/csr_mm.py
@@ -12,10 +12,10 @@
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy._src.math.op_register import (XLACustomOp)
from brainpy.errors import PackageMissingError

ti = import_taichi()
ti = import_taichi(error_if_not_found=False)

__all__ = [
'csrmm',
@@ -98,144 +98,130 @@ def raw_csrmm_taichi(
if data.shape[0] != 1:
return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
else:
if ti is None:
raise PackageMissingError.by_purpose('taichi', 'customized sparse matrix multiplication')
if transpose:
prim = _csr_matmat_transpose_homo_p
else:
prim = _csr_matmat_homo_p
return prim(data,
indices,
indptr,
matrix,
outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)],
transpose=transpose,
shape=shape)
r = prim(indices,
indptr,
matrix,
outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)],
transpose=transpose,
shape=shape)
return [r[0] * data]
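
The rewritten homogeneous path passes only the connectivity to the kernel and rescales afterwards: a homogeneous CSR matrix is its 0/1 pattern times a single weight, so the weight factors out of the product. A minimal SciPy/NumPy sketch of this identity (illustrative names, not part of the commit):

import numpy as np
import scipy.sparse as sp

w = 0.5  # homogeneous weight; corresponds to data.shape == (1,)
mask = sp.random(8, 4, density=0.3, format='csr')
mask.data[:] = 1.0  # 0/1 connectivity pattern
B = np.random.rand(4, 3)

# (w * mask) @ B == w * (mask @ B), which is why the kernel can ignore `data`
# and the caller returns [r[0] * data].
assert np.allclose((w * mask).dot(B), w * mask.dot(B))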


# taichi kernels

@ti.kernel
def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += values[row_i] * matrix[row_i, j]


@ti.kernel
def _csr_matmat_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += values[j] * matrix[col_indices[j], col_k]
out[row_i, col_k] = r


@ti.kernel
def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
ti.loop_config(serialize=True)
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += value * matrix[row_i, j]


@ti.kernel
def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += value * matrix[row_i, j]


@ti.kernel
def _csr_matmat_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)


def _csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
return raw_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose)


def _csr_matmat_transpose(
ct, data, indices, indptr, matrix, *, outs, transpose, shape,
):
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
if ad.is_undefined_primal(matrix):
ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix)

else:
if type(ct[0]) is ad.Zero:
ct_data = ad.Zero(data)
if ti is not None:

# @ti.kernel
# def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
# col_indices: ti.types.ndarray(ndim=1),
# row_ptr: ti.types.ndarray(ndim=1),
# matrix: ti.types.ndarray(ndim=2),
# out: ti.types.ndarray(ndim=2)):
# for row_i in range(row_ptr.shape[0] - 1):
# for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
# col = col_indices[i]
# for j in range(out.shape[1]):
# out[col, j] += values[row_i] * matrix[row_i, j]
#
#
# @ti.kernel
# def _csr_matmat_heter(values: ti.types.ndarray(ndim=1),
# col_indices: ti.types.ndarray(ndim=1),
# row_ptr: ti.types.ndarray(ndim=1),
# matrix: ti.types.ndarray(ndim=2),
# out: ti.types.ndarray(ndim=2)):
# for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
# r = 0.
# for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
# r += values[j] * matrix[col_indices[j], col_k]
# out[row_i, col_k] = r

@ti.kernel
def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
# matrix: (k, n)
# sparse matrix: (m, k)
for j in range(out.shape[1]):  # parallelize along the n dimension
for row_i in range(row_ptr.shape[0] - 1): # loop along the m dimension
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[i], j] += matrix[row_i, j]


@ti.kernel
def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
m = row_ptr.shape[0] - 1
n = matrix.shape[1]
for j, row_i in ti.ndrange(n, m):  # parallelize along the n and m dimensions
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
out[col_indices[i], j] += matrix[row_i, j]


@ti.kernel
def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
# matrix: (k, n)
# sparse matrix: (m, k)
m, n = out.shape
for row_i, col_k in ti.ndrange(m, n):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r
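
Both homogeneous kernels compute with an implicit weight of one; the transpose variants scatter-add into out[col_indices[i], j], which is why the CPU kernel keeps the row loop serial per output column while the GPU kernel parallelizes over both dimensions and relies on Taichi treating `+=` on shared output cells as an atomic add. A dense NumPy oracle the kernels could be tested against (a hedged sketch: A is the m-by-k pattern, matrix is k-by-n, or m-by-n when transposed):

import numpy as np

def csrmm_homo_ref(indices, indptr, matrix, *, shape, transpose):
  # Dense 0/1 pattern of the (m, k) sparse operand; the homogeneous weight
  # is applied by the caller, as in raw_csrmm_taichi.
  m, k = shape
  dense = np.zeros(shape, dtype=matrix.dtype)
  for row in range(m):
    dense[row, indices[indptr[row]:indptr[row + 1]]] = 1.0
  return (dense.T if transpose else dense) @ matrix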


def _csr_matmat_jvp_matrix(mat_dot, col_indices, row_ptr, matrix, *, outs, transpose, shape):
if transpose:
return _csr_matmat_transpose_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs)
else:
if data.aval.shape[0] == 1: # scalar
ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0]
ct_data = jnp.sum(ct[0] * ct_data)
else: # heter
matrix = jnp.asarray(matrix)
row, col = csr_to_coo(indices, indptr)
ct_data = (ct[0][row] * matrix[col]).sum(1)
return ct_data, indices, indptr, matrix


def _define_op(cpu_kernel, gpu_kernel):
prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
prim.defjvp(_csr_matmat_jvp_values, None, None, _csr_matmat_jvp_matrix)
prim.def_transpose_rule(_csr_matmat_transpose)
return prim


# transpose heter
_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter,
gpu_kernel=_csr_matmat_transpose_heter)

# no transpose heter
_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter,
gpu_kernel=_csr_matmat_heter)

# transpose homo
_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu,
gpu_kernel=_csr_matmat_transpose_homo_gpu)

# no transpose homo
_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo,
gpu_kernel=_csr_matmat_homo)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)
return _csr_matmat_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs)


def _csr_matmat_transpose(
ct, col_indices, row_ptr, matrix, *, outs, transpose, shape,
):
if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
assert ad.is_undefined_primal(matrix)
ct_matrix = _csr_matmat_transpose_homo_p(col_indices, row_ptr, ct[0],
shape=shape,
transpose=not transpose,
outs=[jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)])
return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0])
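
Since the product is linear in `matrix`, the JVP above pushes the tangent through the same primitive, and this transpose rule is the usual VJP of a linear map: the opposite-orientation product applied to the cotangent. A quick dense check of that rule with jax (illustrative, not from the source):

import jax
import jax.numpy as jnp

A = jnp.array([[0., 1., 0.],
               [1., 0., 1.]])  # stand-in for the sparse pattern
f = lambda B: A @ B  # linear in B, like the homo primitive

B = jnp.ones((3, 2))
ct = jnp.ones((2, 2))  # cotangent of the output
_, f_vjp = jax.vjp(f, B)
assert jnp.allclose(f_vjp(ct)[0], A.T @ ct)  # VJP == transposed product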


def _define_op(cpu_kernel, gpu_kernel):
prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
prim.defjvp(None, None, _csr_matmat_jvp_matrix)
prim.def_transpose_rule(_csr_matmat_transpose)
return prim


# # transpose heter
# _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter,
# gpu_kernel=_csr_matmat_transpose_heter)
#
# # no transpose heter
# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter,
# gpu_kernel=_csr_matmat_heter)

# transpose homo
_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu,
gpu_kernel=_csr_matmat_transpose_homo_gpu)

# no transpose homo
_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
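
End to end, the public entry point exported above is `csrmm`. A hedged usage sketch, assuming the call signature csrmm(data, indices, indptr, matrix, shape=..., transpose=...) that raw_csrmm_taichi suggests:

import brainpy.math as bm
import jax.numpy as jnp

# 2-by-3 sparse matrix: row 0 -> col 1, row 1 -> cols 0 and 2,
# all entries sharing the homogeneous weight 0.5.
data = jnp.array([0.5])  # data.shape == (1,) selects the accelerated homo path
indices = jnp.array([1, 0, 2])
indptr = jnp.array([0, 1, 3])
B = jnp.ones((3, 4))

out = bm.sparse.csrmm(data, indices, indptr, B, shape=(2, 3), transpose=False)
print(out.shape)  # (2, 4)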
