Skip to content

Commit

Permalink
Fix csr matmat bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Feb 17, 2024
1 parent 7fe78b2 commit b8ef61a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
26 changes: 17 additions & 9 deletions brainpy/_src/math/sparse/_csr_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp)
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse._utils import csr_to_coo

ti = import_taichi()
Expand Down Expand Up @@ -47,7 +47,19 @@ def csrmm(
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product product.
"""
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]
data = jnp.atleast_1d(data)
out_shape = shape[1] if transpose else shape[0]
result_shape = (out_shape, matrix.shape[1])
# if the shape of indices is (0,), then we return a zero matrix

if data.shape[0] != 1:
if indices.shape[0] == 0:
return jnp.zeros(result_shape, dtype=data.dtype)
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
if indices.shape[0] == 0:
return [jnp.zeros(result_shape, dtype=data.dtype), ]
return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]


def raw_csrmm_taichi(
Expand All @@ -65,8 +77,6 @@ def raw_csrmm_taichi(
indptr = as_jax(indptr)
matrix = as_jax(matrix)

data = jnp.atleast_1d(data)

if matrix.dtype == jnp.bool_:
matrix = as_jax(matrix, dtype=data.dtype)

Expand All @@ -75,7 +85,7 @@ def raw_csrmm_taichi(
f'But we got {data.dtype} != {matrix.dtype}.')
assert data.ndim == indices.ndim == indptr.ndim == 1
assert matrix.ndim == 2
data = jnp.atleast_1d(data)

if np.ndim(data) == 1:
if data.shape[0] not in [1, indices.shape[0]]:
raise ValueError('The size of data should be 1 or be consistent with indices.'
Expand All @@ -88,9 +98,6 @@ def raw_csrmm_taichi(

out_shape = shape[1] if transpose else shape[0]
result_shape = (out_shape, matrix.shape[1])
# if the shape of indices is (0,), then we return a zero matrix
if indices.shape[0] == 0:
return [jnp.zeros(result_shape, dtype=data.dtype), ]

assert matrix.shape[0] == (shape[0] if transpose else shape[1])
if data.shape[0] != 1:
Expand Down Expand Up @@ -298,4 +305,5 @@ def _define_op(cpu_kernel, gpu_kernel):
gpu_kernel=_csr_matmat_homo_gpu)

# heter CUSPARSE
_csr_matmat_heter_p = csr.csr_matmat_p
_csr_matmat_heter_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_heter_p)
2 changes: 2 additions & 0 deletions brainpy/_src/math/sparse/tests/test_csrmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ def test_heter(self, transpose, shape):
r1 = (dense.T @ matrix) if transpose else (dense @ matrix)
r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix,
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
print(r2)
print(r1.shape, '-', r2.shape)
c = bm.allclose(r1, r2, equal_nan=True)
if not c:
print(r1 - r2)
Expand Down

0 comments on commit b8ef61a

Please sign in to comment.