Skip to content

Commit

Permalink
Optimize event csr matvec with taichi customized op and Add taichi ev…
Browse files Browse the repository at this point in the history
…ent csr matvec benchmark
  • Loading branch information
Routhleck committed Dec 10, 2023
1 parent 0631447 commit 9ed439b
Show file tree
Hide file tree
Showing 2 changed files with 426 additions and 73 deletions.
193 changes: 120 additions & 73 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'csrmv_taichi'
]

### CPU

@ti.kernel
def _event_csr_matvec_transpose_bool_cpu(values: ti.types.ndarray(ndim=1),
Expand Down Expand Up @@ -115,93 +116,114 @@ def _event_csr_matvec_cpu(values: ti.types.ndarray(ndim=1),
r += values[j]
out[row_i] = r

### GPU
# homo

@ti.kernel
def _event_csr_matvec_transpose_bool_gpu(values: ti.types.ndarray(ndim=1),
def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
value = values[0]
for row_i in ti.ndrange(indptr.shape[0] - 1):
if events[row_i]:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += value


@ti.kernel
def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(indptr.shape[0] - 1):
if events[row_i]:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += value

else:
for row_i in range(indptr.shape[0] - 1):
if events[row_i]:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += values[j]
value = values[0]
for row_i in ti.ndrange(indptr.shape[0] - 1):
if events[row_i] > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += value


@ti.kernel
def _event_csr_matvec_transpose_gpu(values: ti.types.ndarray(ndim=1),
def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(indptr.shape[0] - 1):
if events[row_i] > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += value

else:
for row_i in range(indptr.shape[0] - 1):
if events[row_i] > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
out[indices[j]] += values[j]

value = values[0]
for row_i in ti.ndrange(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]]:
r += value
out[row_i] = r

@ti.kernel
def _event_csr_matvec_bool_gpu(values: ti.types.ndarray(ndim=1),
def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]]:
r += value
out[row_i] = r
value = values[0]
for row_i in ti.ndrange(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]] > 0.:
r += value
out[row_i] = r

else:
for row_i in range(indptr.shape[0] - 1):
r = 0.
# heter

@ti.kernel
def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(indptr.shape[0] - 1):
if events[row_i]:
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]]:
r += values[j]
out[row_i] = r
out[indices[j]] += values[j]


@ti.kernel
def _event_csr_matvec_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(indptr.shape[0] - 1):
r = 0.
def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
indices: ti.types.ndarray(ndim=1),
indptr: ti.types.ndarray(ndim=1),
events: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
for row_i in ti.ndrange(indptr.shape[0] - 1):
if events[row_i] > 0.:
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]] > 0.:
r += value
out[row_i] = r
out[indices[j]] += values[j]

else:
for row_i in range(indptr.shape[0] - 1):
r = 0.
for j in range(indptr[row_i], indptr[row_i + 1]):
if events[indices[j]] > 0.:
r += values[j]
out[row_i] = r

@ti.kernel
def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1),
                                     indices: ti.types.ndarray(ndim=1),
                                     indptr: ti.types.ndarray(ndim=1),
                                     events: ti.types.ndarray(ndim=1),
                                     out: ti.types.ndarray(ndim=1)):
    # Event-driven CSR matrix-vector product (non-transposed) with one weight
    # per stored entry ("heter").
    #
    # For every row, add up values[k] over the stored entries k of that row
    # whose column index points at a truthy event, then store the sum in
    # out[row]. Each out[row] is overwritten, not accumulated into.
    #
    # NOTE(review): presumably `events` holds booleans and `values` has the
    # same length as `indices` -- confirm against the dispatching caller.
    for row in ti.ndrange(indptr.shape[0] - 1):
        # Entries belonging to this row live in the half-open span [first, last).
        first = indptr[row]
        last = indptr[row + 1]
        acc = 0.
        for k in range(first, last):
            if events[indices[k]]:
                acc += values[k]
        out[row] = acc

@ti.kernel
def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
                                indices: ti.types.ndarray(ndim=1),
                                indptr: ti.types.ndarray(ndim=1),
                                events: ti.types.ndarray(ndim=1),
                                out: ti.types.ndarray(ndim=1)):
    # Event-driven CSR matrix-vector product (non-transposed) with one weight
    # per stored entry ("heter"), for numeric events.
    #
    # For every row, add up values[k] over the stored entries k of that row
    # whose column index points at an event strictly greater than zero, then
    # store the sum in out[row]. Each out[row] is overwritten, not
    # accumulated into.
    #
    # NOTE(review): presumably `values` has the same length as `indices` --
    # confirm against the dispatching caller.
    for row in ti.ndrange(indptr.shape[0] - 1):
        # Entries belonging to this row live in the half-open span [first, last).
        first = indptr[row]
        last = indptr[row + 1]
        acc = 0.
        for k in range(first, last):
            if events[indices[k]] > 0.:
                acc += values[k]
        out[row] = acc


def _event_csr_matvec_jvp_values(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
Expand Down Expand Up @@ -313,14 +335,26 @@ def csrmv_taichi(

if transpose:
if events.dtype == jnp.bool_:
prim = _event_csrmv_transpose_bool_p
if events.shape[0] == 1:
prim = _event_csrmv_transpose_bool_homo_p
else:
prim = _event_csrmv_transpose_bool_heter_p
else:
prim = _event_csrmv_transpose_p
if events.shape[0] == 1:
prim = _event_csrmv_transpose_homo_p
else:
prim = _event_csrmv_transpose_heter_p
else:
if events.dtype == jnp.bool_:
prim = _event_csrmv_bool_p
if events.shape[0] == 1:
prim = _event_csrmv_bool_homo_p
else:
prim = _event_csrmv_bool_heter_p
else:
prim = _event_csrmv_p
if events.shape[0] == 1:
prim = _event_csrmv_homo_p
else:
prim = _event_csrmv_heter_p

# computing
return prim(data,
Expand All @@ -339,14 +373,27 @@ def _define_op(cpu_kernel, gpu_kernel):
return prim


# transpose bool
_event_csrmv_transpose_bool_p = _define_op(_event_csr_matvec_transpose_bool_cpu, _event_csr_matvec_transpose_bool_gpu)
# transpose bool homo
_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_cpu, _event_csr_matvec_transpose_bool_homo_gpu)

# transpose homo
_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_cpu, _event_csr_matvec_transpose_homo_gpu)

# not transpose bool homo
_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_cpu, _event_csr_matvec_bool_homo_gpu)

# not transpose homo
_event_csrmv_homo_p = _define_op(_event_csr_matvec_cpu, _event_csr_matvec_homo_gpu)

# transpose bool heter
_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_cpu, _event_csr_matvec_transpose_bool_heter_gpu)

# transpose heter
_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_cpu, _event_csr_matvec_transpose_heter_gpu)

# transpose
_event_csrmv_transpose_p = _define_op(_event_csr_matvec_transpose_cpu, _event_csr_matvec_transpose_gpu)
# not transpose bool heter
_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_cpu, _event_csr_matvec_bool_heter_gpu)

# not transpose bool
_event_csrmv_bool_p = _define_op(_event_csr_matvec_bool_cpu, _event_csr_matvec_bool_gpu)
# not transpose heter
_event_csrmv_heter_p = _define_op(_event_csr_matvec_cpu, _event_csr_matvec_heter_gpu)

# not transpose
_event_csrmv_p = _define_op(_event_csr_matvec_cpu, _event_csr_matvec_gpu)
Loading

0 comments on commit 9ed439b

Please sign in to comment.