Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 25, 2024
1 parent 1c28418 commit 68e25f9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
35 changes: 22 additions & 13 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@
import numpy as np
from jax import numpy as jnp
from jax.interpreters import ad
from jax.experimental.sparse import csr

from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp)
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy._src.math.defaults import float_

ti = import_taichi()

Expand Down Expand Up @@ -86,23 +88,26 @@ def raw_event_csrmm_taichi(
return [jnp.zeros(result_shape, dtype=data.dtype), ]

assert matrix.shape[0] == (shape[0] if transpose else shape[1])
if transpose:

# homo -> taichi
# heter -> cusparse
if data.shape[0] != 1:
if matrix.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csr_matmat_transpose_bool_homo_p
# change dtype to float
matrix = matrix.astype(float_)
return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
else:
if transpose:
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_transpose_homo_p
else:
prim = _event_csr_matmat_transpose_bool_heter_p
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
else:
if matrix.dtype == jnp.bool_:
if data.shape[0] == 1:
if matrix.dtype == jnp.bool_:
prim = _event_csr_matmat_bool_homo_p
else:
prim = _event_csr_matmat_bool_heter_p
else:
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
return prim(data,
return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
return prim(data,
indices,
indptr,
matrix,
Expand Down Expand Up @@ -299,3 +304,7 @@ def _define_op(cpu_kernel, gpu_kernel):
# bool no transpose homo
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
gpu_kernel=_event_csr_matmat_bool_homo)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)
8 changes: 3 additions & 5 deletions brainpy/_src/math/sparse/csr_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ def raw_csrmm_taichi(
# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter,
# gpu_kernel=_csr_matmat_heter)


@ti.kernel
def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
Expand Down Expand Up @@ -203,10 +202,9 @@ def _csr_matmat_transpose(
if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
assert ad.is_undefined_primal(matrix)
ct_matrix = _csr_matmat_transpose_homo_p(col_indices, row_ptr, ct[0],
shape=shape,
transpose=not transpose,
outs=[jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)])
ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0],
shape=shape,
transpose=not transpose)
return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0])


Expand Down

0 comments on commit 68e25f9

Please sign in to comment.