Skip to content

Commit

Permalink
Update event csr matvec
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 4, 2024
1 parent b9e5584 commit 861b340
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 24 deletions.
24 changes: 12 additions & 12 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
r += val * matrix[row_j, col_i]
out[row_k, col_i] = r


Expand All @@ -151,7 +151,7 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1),
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
r += val
r += val * matrix[row_j, col_i]
out[row_k, col_i] = r


Expand Down Expand Up @@ -196,9 +196,9 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand All @@ -214,9 +214,9 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1),
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand Down Expand Up @@ -264,7 +264,7 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
val = values[j] * matrix[row_j, col_i]
r += val
out[row_k, col_i] = r

Expand All @@ -282,7 +282,7 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
val = 0.
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
val = values[j]
val = values[j] * matrix[row_j, col_i]
r += val
out[row_k, col_i] = r

Expand Down Expand Up @@ -328,9 +328,9 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
if matrix[row_j, col_i] != 0.:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand All @@ -346,9 +346,9 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
if matrix[row_j, col_i]:
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand Down
10 changes: 2 additions & 8 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,21 +131,15 @@ def raw_csrmv_taichi(
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
else:
if events.dtype == jnp.bool_:
if data.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
if data.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p
return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)

# computing
return prim(data,
Expand Down
8 changes: 4 additions & 4 deletions brainpy/_src/math/sparse/csr_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
for row_j in range(matrix.shape[0]):
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand Down Expand Up @@ -225,9 +225,9 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
for row_j in range(matrix.shape[0]):
for j in range(row_ptr[row_j], row_ptr[row_j + 1]):
if col_indices[j] == row_k:
r += value * matrix[row_j, col_i]
r += matrix[row_j, col_i]
break
out[row_k, col_i] = r
out[row_k, col_i] = r * value


@ti.kernel
Expand Down

0 comments on commit 861b340

Please sign in to comment.