[math] Optimize event csr matvec and matmat operators #675

Closed · wants to merge 9 commits
71 changes: 55 additions & 16 deletions brainpy/_src/math/event/csr_matmat.py
@@ -6,16 +6,15 @@
import jax
import numpy as np
from jax import numpy as jnp
-from jax.interpreters import ad
from jax.experimental.sparse import csr
+from jax.interpreters import ad

from brainpy._src.dependency_check import import_taichi
+from brainpy._src.math.defaults import float_
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
-from brainpy._src.math.sparse.utils import csr_to_coo
-from brainpy._src.math.defaults import float_

ti = import_taichi()

@@ -99,21 +98,21 @@ def raw_event_csrmm_taichi(
  else:
    if transpose:
      if matrix.dtype == jnp.bool_:
-       prim = _event_csr_matmat_transpose_homo_p
+       prim = _event_csr_matmat_transpose_bool_homo_p
      else:
-       return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+       prim = _event_csr_matmat_transpose_homo_p
    else:
      if matrix.dtype == jnp.bool_:
        prim = _event_csr_matmat_bool_homo_p
      else:
-       return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+       prim = _event_csr_matmat_homo_p
    return prim(data,
                indices,
                indptr,
                matrix,
                outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)],
                transpose=transpose,
                shape=shape)
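As a reading aid, here is a minimal dense NumPy reference (not part of the PR; the helper name and the shape=(m, n) convention for the sparse operand are assumptions) for what the homogeneous event kernels dispatched above compute:

import numpy as np

def event_csrmm_homo_ref(value, indices, indptr, matrix, *, shape, transpose=False):
  # Densify the CSR pattern with a single shared weight `value`.
  m, n = shape
  A = np.zeros((m, n))
  for i in range(m):
    A[i, indices[indptr[i]:indptr[i + 1]]] = value
  # Events: boolean entries contribute 0/1; float events contribute their value.
  M = matrix.astype(A.dtype)
  return A.T @ M if transpose else A @ M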


# taichi kernels
@@ -234,8 +233,38 @@ def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1),
out[row_i, col_k] = r * value


+@ti.kernel
+def _event_csr_matmat_dw(dy: ti.types.ndarray(ndim=2),
+                         col_indices: ti.types.ndarray(ndim=1),
+                         row_ptr: ti.types.ndarray(ndim=1),
+                         matrix: ti.types.ndarray(ndim=2),
+                         dw: ti.types.ndarray(ndim=1)):
+  # Value cotangent: for every stored nonzero j in row i,
+  # dw[j] = sum_k dy[i, k] * matrix[col_indices[j], k] over event entries.
+  for row_i in range(row_ptr.shape[0] - 1):
+    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      temp_sum = 0.
+      for col_k in range(matrix.shape[1]):
+        if matrix[col_indices[j], col_k] != 0.:
+          temp_sum += dy[row_i, col_k] * matrix[col_indices[j], col_k]
+      dw[j] = temp_sum


+@ti.kernel
+def _event_csr_matmat_bool_dw(dy: ti.types.ndarray(ndim=2),
+                              col_indices: ti.types.ndarray(ndim=1),
+                              row_ptr: ti.types.ndarray(ndim=1),
+                              matrix: ti.types.ndarray(ndim=2),
+                              dw: ti.types.ndarray(ndim=1)):
+  # Boolean events contribute 0/1, so the matrix factor reduces to a gate.
+  for row_i in range(row_ptr.shape[0] - 1):
+    for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+      temp_sum = 0.
+      for col_k in range(matrix.shape[1]):
+        if matrix[col_indices[j], col_k]:
+          temp_sum += dy[row_i, col_k]
+      dw[j] = temp_sum
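A dense sketch of the value cotangent these kernels produce (an assumed helper, not in the PR): each stored nonzero j in row i receives the inner product of the i-th cotangent row with the matrix row it touches; zero (non-event) entries drop out of the sum automatically.

import numpy as np

def event_csr_matmat_dw_ref(dy, col_indices, row_ptr, matrix):
  # dw[j] = sum_k dy[row(j), k] * matrix[col_indices[j], k]
  dw = np.zeros(col_indices.shape[0])
  for i in range(row_ptr.shape[0] - 1):
    for j in range(row_ptr[i], row_ptr[i + 1]):
      dw[j] = float(dy[i] @ matrix[col_indices[j]].astype(dy.dtype))
  return dw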


def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
-  return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)
+  return raw_event_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose)
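Because the forward product is linear in the stored values, the JVP with respect to `data` is the same event operator applied to the tangent `val_dot`; routing it through `raw_event_csrmm_taichi` keeps the tangent computation on the event-aware kernels instead of the generic `normal_csrmm` path.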


def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape):
@@ -261,8 +290,12 @@ def _event_csr_matmat_transpose(
    ct_data = jnp.sum(ct[0] * ct_data)
  else:  # heter
    matrix = jnp.asarray(matrix)
-   row, col = csr_to_coo(indices, indptr)
-   ct_data = (ct[0][row] * matrix[col]).sum(1)
+   if matrix.dtype == jnp.bool_:
+     prim = _event_csr_matmat_bool_dw_p
+   else:
+     prim = _event_csr_matmat_dw_p
+   ct_data = prim(ct[0], indices, indptr, matrix,
+                  outs=[jax.ShapeDtypeStruct((data.aval.shape[0],), data.aval.dtype)])[0]
  return ct_data, indices, indptr, matrix
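A hedged sanity check for this transpose rule (an assumed test snippet, run inside this module so `raw_event_csrmm_taichi` is in scope): the gradient of a scalar loss through the homogeneous primitive should match the dense reference above.

import jax
import jax.numpy as jnp

data = jnp.ones(1)                                    # homogeneous weight
indices = jnp.asarray([0, 2, 1])
indptr = jnp.asarray([0, 2, 3])
matrix = jnp.asarray([[1., 0.], [0., 1.], [1., 1.]])  # float "events"
shape = (2, 3)

f = lambda d: raw_event_csrmm_taichi(d, indices, indptr, matrix, shape=shape)[0].sum()
g = jax.grad(f)(data)  # sum of event entries hit by the pattern: [4.]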


@@ -305,6 +338,12 @@ def _define_op(cpu_kernel, gpu_kernel):
_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
gpu_kernel=_event_csr_matmat_bool_homo)

+_event_csr_matmat_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_dw,
+                                     gpu_kernel=_event_csr_matmat_dw)
+
+_event_csr_matmat_bool_dw_p = XLACustomOp(cpu_kernel=_event_csr_matmat_bool_dw,
+                                          gpu_kernel=_event_csr_matmat_bool_dw)

# heter CUSPARSE
_csr_matmat_cusparse_p = csr.csr_matmat_p
register_general_batching(_csr_matmat_cusparse_p)
109 changes: 95 additions & 14 deletions brainpy/_src/math/event/csr_matvec.py
@@ -21,7 +21,6 @@
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import XLACustomOp
from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi
-from brainpy._src.math.sparse.utils import csr_to_coo
from brainpy.errors import PackageMissingError

__all__ = [
@@ -131,15 +130,21 @@ def raw_csrmv_taichi(
      else:
        prim = _event_csrmv_transpose_bool_heter_p
    else:
-     return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
+     if data.shape[0] == 1:
+       prim = _event_csrmv_transpose_homo_p
+     else:
+       prim = _event_csrmv_transpose_heter_p
  else:
    if events.dtype == jnp.bool_:
      if data.shape[0] == 1:
        prim = _event_csrmv_bool_homo_p
      else:
        prim = _event_csrmv_bool_heter_p
    else:
-     return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)
+     if data.shape[0] == 1:
+       prim = _event_csrmv_homo_p
+     else:
+       prim = _event_csrmv_heter_p

  # computing
  return prim(data,
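For reference, a dense NumPy sketch (an assumed helper, not part of the PR) covering both the homogeneous and heterogeneous event matvec paths dispatched above:

import numpy as np

def event_csrmv_ref(data, indices, indptr, events, *, shape, transpose=False):
  # Densify the CSR matrix; a length-1 `data` is a shared (homogeneous) weight.
  m, n = shape
  A = np.zeros((m, n))
  for i in range(m):
    cols = indices[indptr[i]:indptr[i + 1]]
    A[i, cols] = data[0] if data.shape[0] == 1 else data[indptr[i]:indptr[i + 1]]
  ev = events.astype(A.dtype)  # bool spikes -> 0/1; float events keep magnitude
  return A.T @ ev if transpose else A @ ev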
@@ -201,7 +206,7 @@ def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1),
  for row_i in range(indptr.shape[0] - 1):
    if events[row_i] != 0.:
      for j in range(indptr[row_i], indptr[row_i + 1]):
-       out[indices[j]] += value
+       out[indices[j]] += value * events[row_i]
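Note the semantic change in these forward kernels: contributions are now scaled by the event value, so graded (float) events act multiplicatively — e.g. with value = 2.0 and events[row_i] = 0.5 the accumulated term is 1.0 rather than 2.0 — while boolean and 0/1 events are unaffected, since the `events[row_i] != 0.` guard already restricts the loop to active rows.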


@ti.kernel
@@ -214,7 +219,7 @@ def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1),
  for row_i in range(indptr.shape[0] - 1):
    if events[row_i] != 0.:
      for j in range(indptr[row_i], indptr[row_i + 1]):
-       out[indices[j]] += values[j]
+       out[indices[j]] += values[j] * events[row_i]


@ti.kernel
@@ -260,7 +265,7 @@ def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1),
    r = 0.
    for j in range(indptr[row_i], indptr[row_i + 1]):
      if events[indices[j]] != 0.:
-       r += value
+       r += value * events[indices[j]]
    out[row_i] = r


@@ -275,7 +280,7 @@ def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1),
    r = 0.
    for j in range(indptr[row_i], indptr[row_i + 1]):
      if events[indices[j]] != 0.:
-       r += values[j]
+       r += values[j] * events[indices[j]]
    out[row_i] = r


@@ -318,7 +323,7 @@ def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1),
      j = indptr[row_i] + index
      end_index = indptr[row_i + 1]
      while j < end_index:
-       out[indices[j]] += value
+       out[indices[j]] += value * events[row_i]
        j += 32


@@ -365,7 +370,7 @@ def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1),
      end_index = indptr[row_i + 1]
      while j < end_index:
        if events[indices[j]] != 0.:
-         r += value
+         r += value * events[indices[j]]
        j += 32
      out[row_i] += r  # TODO: warp-level primitive

@@ -400,7 +405,7 @@ def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1),
      j = indptr[row_i] + index
      end_index = indptr[row_i + 1]
      while j < end_index:
-       out[indices[j]] += values[j]
+       out[indices[j]] += values[j] * events[row_i]
        j += 32


@@ -437,13 +442,61 @@ def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
      end_index = indptr[row_i + 1]
      while j < end_index:
        if events[indices[j]] != 0.:
-         r += values[j]
+         r += values[j] * events[indices[j]]
        j += 32
      out[row_i] += r  # TODO: warp-level primitive


+@ti.kernel
+def _event_csr_matvec_dW_transpose(dy: ti.types.ndarray(),
+                                   indices: ti.types.ndarray(),
+                                   indptr: ti.types.ndarray(),
+                                   events: ti.types.ndarray(),
+                                   dw: ti.types.ndarray()):
+  # transpose case: dw[j] = dy[indices[j]] * events[row(j)]
+  for i in range(events.shape[0]):
+    if events[i] != 0.:
+      for j in range(indptr[i], indptr[i + 1]):
+        dw[j] = dy[indices[j]] * events[i]


+@ti.kernel
+def _event_csr_matvec_dW_bool_transpose(dy: ti.types.ndarray(),
+                                        indices: ti.types.ndarray(),
+                                        indptr: ti.types.ndarray(),
+                                        events: ti.types.ndarray(),
+                                        dw: ti.types.ndarray()):
+  for i in range(events.shape[0]):
+    if events[i]:
+      for j in range(indptr[i], indptr[i + 1]):
+        dw[j] = dy[indices[j]]


+@ti.kernel
+def _event_csr_matvec_dW(dy: ti.types.ndarray(),
+                         indices: ti.types.ndarray(),
+                         indptr: ti.types.ndarray(),
+                         events: ti.types.ndarray(),
+                         dw: ti.types.ndarray()):
+  # non-transpose case: dw[j] = dy[row(j)] * events[indices[j]]
+  for row_i in range(indptr.shape[0] - 1):
+    for j in range(indptr[row_i], indptr[row_i + 1]):
+      if events[indices[j]] != 0.:
+        dw[j] = dy[row_i] * events[indices[j]]


+@ti.kernel
+def _event_csr_matvec_dW_bool(dy: ti.types.ndarray(),
+                              indices: ti.types.ndarray(),
+                              indptr: ti.types.ndarray(),
+                              events: ti.types.ndarray(),
+                              dw: ti.types.ndarray()):
+  for row_i in range(indptr.shape[0] - 1):
+    for j in range(indptr[row_i], indptr[row_i + 1]):
+      if events[indices[j]]:
+        dw[j] = dy[row_i]
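A dense sketch (an assumed helper, not in the PR) of what the four dW kernels compute, one entry per stored nonzero:

import numpy as np

def event_csrmv_dw_ref(dy, indices, indptr, events, *, transpose):
  # transpose:     dw[j] = dy[indices[j]] * events[row(j)]
  # non-transpose: dw[j] = dy[row(j)]     * events[indices[j]]
  dw = np.zeros(indices.shape[0])
  ev = events.astype(dw.dtype)  # bool events gate with 0/1
  for i in range(indptr.shape[0] - 1):
    for j in range(indptr[i], indptr[i + 1]):
      dw[j] = dy[indices[j]] * ev[i] if transpose else dy[i] * ev[indices[j]]
  return dw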


def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape):
-  return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)
+  return raw_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose)


def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape):
@@ -466,8 +519,20 @@ def _event_csr_matvec_transpose_taichi(
    ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
    ct_values = jnp.inner(ct[0], ct_values)
  else:  # heterogeneous values
-   row, col = csr_to_coo(indices, indptr)
-   ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
+   # row, col = csr_to_coo(indices, indptr)
+   # ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
+   if transpose:
+     if events.dtype == jnp.bool_:
+       prim = _event_csrmv_dW_bool_p_transpose
+     else:
+       prim = _event_csrmv_dW_p_transpose
+   else:
+     if events.dtype == jnp.bool_:
+       prim = _event_csrmv_dW_bool_p
+     else:
+       prim = _event_csrmv_dW_p
+   ct_values = prim(ct[0], indices, indptr, events,
+                    outs=[jax.ShapeDtypeStruct((values.aval.shape[0],), values.aval.dtype)])[0]
  return ct_values, indices, indptr, events
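A hedged end-to-end check of this rule (an assumed snippet, run inside this module so `raw_csrmv_taichi` is in scope):

import jax
import jax.numpy as jnp

indptr = jnp.asarray([0, 2, 3])
indices = jnp.asarray([0, 1, 1])
data = jnp.asarray([1., 2., 3.])     # heterogeneous values
events = jnp.asarray([True, False])  # a single presynaptic spike
shape = (2, 2)

f = lambda d: raw_csrmv_taichi(d, indices, indptr, events, shape=shape, transpose=True)[0].sum()
print(jax.grad(f)(data))             # dense reference gives [1., 1., 0.]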


@@ -509,3 +574,19 @@ def _define_op(cpu_kernel, gpu_kernel):
# not transpose heter
_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu,
_event_csr_matvec_heter_gpu)

+# compute dW transpose
+_event_csrmv_dW_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_transpose,
+                                          gpu_kernel=_event_csr_matvec_dW_transpose)
+
+# compute dW bool transpose
+_event_csrmv_dW_bool_p_transpose = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool_transpose,
+                                               gpu_kernel=_event_csr_matvec_dW_bool_transpose)
+
+# compute dW
+_event_csrmv_dW_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW,
+                                gpu_kernel=_event_csr_matvec_dW)
+
+# compute dW bool
+_event_csrmv_dW_bool_p = XLACustomOp(cpu_kernel=_event_csr_matvec_dW_bool,
+                                     gpu_kernel=_event_csr_matvec_dW_bool)
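Finally, a rough micro-benchmark sketch for the optimized operator (all sizes and the sparsity level are made-up assumptions; run inside this module so `raw_csrmv_taichi` is in scope):

import time
import numpy as np
import jax
import jax.numpy as jnp

n_pre, n_post, nnz = 1000, 1000, 10000
rng = np.random.default_rng(0)
indptr = jnp.asarray(np.linspace(0, nnz, n_pre + 1).astype(np.int32))
indices = jnp.asarray(rng.integers(0, n_post, nnz).astype(np.int32))
data = jnp.ones(1)                             # homogeneous weight
events = jnp.asarray(rng.random(n_pre) < 0.1)  # ~10% of rows spike

f = jax.jit(lambda: raw_csrmv_taichi(data, indices, indptr, events, shape=(n_pre, n_post))[0])
f().block_until_ready()                        # warm up / compile
t0 = time.perf_counter()
f().block_until_ready()
print('event csrmv:', time.perf_counter() - t0, 's')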