Skip to content

Commit

Permalink
Update csrmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 7, 2024
1 parent fff64db commit 9f3b9b2
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 124 deletions.
50 changes: 33 additions & 17 deletions brainpy/_src/math/sparse/csr_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def raw_csrmm_taichi(

if indices.shape[0] == 0:
return [jnp.zeros(result_shape, dtype=data.dtype), ]

# homo -> taichi,
# heter -> cusparse
if data.shape[0] != 1:
Expand All @@ -118,11 +119,11 @@ def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += values[j] * matrix[row_j, col_i]
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += values[row_i] * matrix[row_i, j]


@ti.kernel
Expand All @@ -139,17 +140,32 @@ def _csr_matmat_heter(values: ti.types.ndarray(ndim=1),


@ti.kernel
def _csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
ti.loop_config(serialize=True)
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += value * matrix[row_i, j]


@ti.kernel
def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
col_indices: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
for row_j in range(matrix.shape[0]):
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
out[row_k, col_i] += value * matrix[row_j, col_i]
for row_i in range(row_ptr.shape[0] - 1):
for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
col = col_indices[i]
for j in range(out.shape[1]):
out[col, j] += value * matrix[row_i, j]


@ti.kernel
Expand Down Expand Up @@ -213,8 +229,8 @@ def _define_op(cpu_kernel, gpu_kernel):
gpu_kernel=_csr_matmat_heter)

# transpose homo
_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo,
gpu_kernel=_csr_matmat_transpose_homo)
_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu,
gpu_kernel=_csr_matmat_transpose_homo_gpu)

# no transpose homo
_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo,
Expand Down
Loading

0 comments on commit 9f3b9b2

Please sign in to comment.