Skip to content

Commit

Permalink
[math] Update event csrmm and csrmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 3, 2024
1 parent 6b7f5fe commit 8737b93
Show file tree
Hide file tree
Showing 9 changed files with 736 additions and 212 deletions.
1 change: 1 addition & 0 deletions brainpy/_src/math/event/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .csr_matvec import *
from .csr_matmat import *

Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp)
from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmm
from brainpy._src.math.sparse._utils import csr_to_coo
from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
from brainpy._src.math.sparse.utils import csr_to_coo

ti = import_taichi()

Expand Down Expand Up @@ -125,17 +125,16 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -144,17 +143,16 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -163,13 +161,12 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j]
out[row_i, col_k] = r
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
Expand All @@ -178,13 +175,12 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j]
out[row_i, col_k] = r
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
Expand All @@ -194,16 +190,15 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -213,16 +208,15 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -232,13 +226,12 @@ def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


@ti.kernel
Expand All @@ -248,13 +241,12 @@ def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


# GPU kernels
Expand All @@ -265,17 +257,16 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -284,17 +275,16 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -303,13 +293,12 @@ def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j]
out[row_i, col_k] = r
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
Expand All @@ -318,13 +307,12 @@ def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
row_ptr: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j]
out[row_i, col_k] = r
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += values[row_j] * matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r


@ti.kernel
Expand All @@ -334,16 +322,15 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -353,16 +340,15 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for col_i in range(out.shape[1]):
for row_k in range(out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r
for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]):
r = 0.
for row_j in range(matrix.shape[0]):
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
break
out[row_k, col_i] = r


@ti.kernel
Expand All @@ -372,13 +358,12 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k] != 0.:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


@ti.kernel
Expand All @@ -388,13 +373,12 @@ def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
matrix: ti.types.ndarray(ndim=2),
out: ti.types.ndarray(ndim=2)):
value = values[0]
for row_i in range(out.shape[0]):
for col_k in range(out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value
for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
r = 0.
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]):
if matrix[col_indices[row_j], col_k]:
r += matrix[col_indices[row_j], col_k]
out[row_i, col_k] = r * value


def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
Expand Down
Loading

0 comments on commit 8737b93

Please sign in to comment.