Skip to content

Commit

Permalink
[math] fix and update taichi jitconn operators
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Dec 24, 2023
1 parent f840ada commit a14b6a5
Show file tree
Hide file tree
Showing 10 changed files with 1,991 additions and 2,308 deletions.
11 changes: 5 additions & 6 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from brainpy._src.math.op_register import XLACustomOp
from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo
from brainpy._src.math.taichi_support import warp_reduce_sum

ti = import_taichi()

Expand Down Expand Up @@ -197,7 +196,7 @@ def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
# should be improved, since the atomic_add for each thread is not
# very efficient. Instead, the warp-level reduction primitive
# should be used.
# see ``warp_reduce_sum()`` function in taichi_support.py.
# see ``warp_reduce_sum()`` function in tifunc.py.
# However, currently Taichi does not support general warp-level primitives.


Expand All @@ -218,7 +217,7 @@ def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]]:
r += value
j += 32
out[row_i] += r
out[row_i] += r # TODO: warp-level primitive


@ti.kernel
Expand All @@ -238,7 +237,7 @@ def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]] != 0.:
r += value
j += 32
out[row_i] += r
out[row_i] += r # TODO: warp-level primitive


@ti.kernel
Expand Down Expand Up @@ -291,7 +290,7 @@ def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]]:
r += values[j]
j += 32
out[row_i] += r
out[row_i] += r # TODO: warp-level primitive


@ti.kernel
Expand All @@ -310,7 +309,7 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
if events[indices[j]] != 0.:
r += values[j]
j += 32
out[row_i] += r
out[row_i] += r # TODO: warp-level primitive


def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
Expand Down
Loading

0 comments on commit a14b6a5

Please sign in to comment.