Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] Implementing event-driven sparse matrix @ matrix operators #613

Merged
merged 27 commits into from
May 27, 2024
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b06e817
[math] Add sparse `matrix@matrix` operators
Routhleck Jan 30, 2024
90516c1
[math] Implement csr matrix @ matrix operator
Routhleck Feb 1, 2024
787d6a8
Update _csr_mm.py
Routhleck Feb 1, 2024
5de7807
[math] Implement event csr matmat
Routhleck Feb 3, 2024
8b403c5
[math] Support data is homo for csr matmat op
Routhleck Feb 4, 2024
eb09d0e
Update _csr_matmat.py
Routhleck Feb 4, 2024
a212497
Format codes
Routhleck Feb 4, 2024
343d8a0
Format codes and import
Routhleck Feb 17, 2024
28d41ff
Merge pull request #621 from brainpy/master
Routhleck Feb 17, 2024
7fe78b2
[math] Replace csr matmat heter operators with CUSPARSE
Routhleck Feb 17, 2024
b8ef61a
Fix csr matmat bugs
Routhleck Feb 17, 2024
c37ce92
Fix bug
Routhleck Feb 17, 2024
9ea9800
Update operator selection strategy for csr matmat
Routhleck Feb 17, 2024
6b7f5fe
Merge branch 'master' of https://github.com/brainpy/BrainPy into add-…
Routhleck Mar 3, 2024
8737b93
[math] Update event csrmm and csrmm
Routhleck Mar 3, 2024
b9e5584
Update sparse.py
Routhleck Mar 3, 2024
861b340
Update event csr matvec
Routhleck Mar 4, 2024
fff64db
Update new transpose taichi kernels of csrmm and event csrmm
Routhleck Mar 7, 2024
9f3b9b2
Update csrmm
Routhleck Mar 7, 2024
116df57
Merge branch 'master' into add-matmat-op
chaoming0625 Mar 8, 2024
c298f48
accelerate csrmm homo
chaoming0625 Mar 8, 2024
f355d18
upgrade
chaoming0625 Mar 8, 2024
c92ace4
Update csr_matmat_VS_cusparse_csr_matmat.py
Routhleck Mar 8, 2024
2589820
update
chaoming0625 Mar 9, 2024
1c28418
updates
chaoming0625 Mar 10, 2024
68e25f9
Fix bugs
Routhleck May 25, 2024
23ef2fc
Merge branch 'master' into add-matmat-op
Routhleck May 25, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions brainpy/_src/math/event/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .csr_matvec import *
from .csr_matmat import *

310 changes: 310 additions & 0 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,310 @@
# -*- coding: utf-8 -*-


from typing import Union, Tuple

import jax
import numpy as np
from jax import numpy as jnp
from jax.interpreters import ad
from jax.experimental.sparse import csr

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy._src.math.defaults import float_

ti = import_taichi()

__all__ = [
'csrmm',
]


def csrmm(
data: Union[float, jnp.ndarray, Array],
indices: Union[jnp.ndarray, Array],
indptr: Union[jnp.ndarray, Array],
matrix: Union[jnp.ndarray, Array],
*,
shape: Tuple[int, int],
transpose: bool = False,
):
"""Product of CSR sparse matrix and a dense event matrix.

Args:
data : array of shape ``(nse,)``, float.
indices : array of shape ``(nse,)``
indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``
B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
dtype ``data.dtype``
shape : length-2 tuple representing the matrix shape
transpose : boolean specifying whether to transpose the sparse matrix
before computing.

Returns:
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product product.
"""
return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]


def raw_event_csrmm_taichi(
data: Union[float, jnp.ndarray, Array],
indices: Union[jnp.ndarray, Array],
indptr: Union[jnp.ndarray, Array],
matrix: Union[jnp.ndarray, Array],
*,
shape: Tuple[int, int],
transpose: bool = False,
):
assert len(shape) == 2

data = jnp.atleast_1d(data)
if np.ndim(data) == 1:
if data.shape[0] not in [1, indices.shape[0]]:
raise ValueError('The size of data should be 1 or be consistent with indices.'
f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')

indices = as_jax(indices)
indptr = as_jax(indptr)
matrix = as_jax(matrix)

assert data.ndim == indices.ndim == indptr.ndim == 1
assert matrix.ndim == 2
assert indptr.shape[0] == shape[0] + 1
if not jnp.issubdtype(indices.dtype, jnp.integer):
raise ValueError('indices should be a 1D vector with integer type.')
if not jnp.issubdtype(indptr.dtype, jnp.integer):
raise ValueError('indptr should be a 1D vector with integer type.')

out_shape = shape[1] if transpose else shape[0]
result_shape = (out_shape, matrix.shape[1])
# if the shape of indices is (0,), then we return a zero matrix
if indices.shape[0] == 0:
return [jnp.zeros(result_shape, dtype=data.dtype), ]

assert matrix.shape[0] == (shape[0] if transpose else shape[1])

# homo -> taichi
# heter -> cusparse
if data.shape[0] != 1:
if matrix.dtype == jnp.bool_:
# change dtype to float
matrix = matrix.astype(float_)
return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
else:
if transpose:
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_transpose_homo_p
else:
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_bool_homo_p
else:
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
return prim(data,
indices,
indptr,
matrix,
outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)],
transpose=transpose,
shape=shape)


# taichi kernels

@ti.kernel
def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += values[j] * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += values[j] * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += value * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += value * matrix[row_j, col_i]


@ti.kernel
def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


@ti.kernel
def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)


def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose)


def _event_csr_matmat_transpose(
ct, data, indices, indptr, matrix, *, outs, transpose, shape,
):
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
if ad.is_undefined_primal(matrix):
ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix)

else:
if type(ct[0]) is ad.Zero:
ct_data = ad.Zero(data)
else:
if data.aval.shape[0] == 1: # scalar
ct_data = \
raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0]
ct_data = jnp.sum(ct[0] * ct_data)
else: # heter
matrix = jnp.asarray(matrix)
row, col = csr_to_coo(indices, indptr)
ct_data = (ct[0][row] * matrix[col]).sum(1)
return ct_data, indices, indptr, matrix


def _define_op(cpu_kernel, gpu_kernel):
prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel)
prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix)
prim.def_transpose_rule(_event_csr_matmat_transpose)
return prim


# transpose heter
_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter,
gpu_kernel=_event_csr_matmat_transpose_heter)

# no transpose heter
_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter,
gpu_kernel=_event_csr_matmat_heter)

# transpose homo
_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo,
gpu_kernel=_event_csr_matmat_transpose_homo)

# no transpose homo
_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo,
gpu_kernel=_event_csr_matmat_homo)

# bool transpose heter
_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter,
gpu_kernel=_event_csr_matmat_transpose_bool_heter)

# bool no transpose heter
_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter,
gpu_kernel=_event_csr_matmat_bool_heter)

# bool transpose homo
_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo,
gpu_kernel=_event_csr_matmat_transpose_bool_homo)

# bool no transpose homo
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
gpu_kernel=_event_csr_matmat_bool_homo)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)
10 changes: 2 additions & 8 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
@@ -131,21 +131,15 @@ def raw_csrmv_taichi(
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
else:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)

# computing
return prim(data,
Loading