Skip to content

Commit

Permalink
Test sparse csr matvec using taichi customized op
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 1, 2023
1 parent 5482a71 commit 05cf134
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 57 deletions.
15 changes: 10 additions & 5 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,14 @@ def _event_csr_matvec_jvp(
events,
shape=shape,
transpose=transpose,)

else:
dr = normal_csrmv_taichi(values_dot,
indices,
indptr,
events_dot,
shape=shape,
transpose=transpose)

return r, dr

def _event_csr_matvec_transpose(ct,
Expand All @@ -254,14 +261,14 @@ def _event_csr_matvec_transpose(ct,
if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
raise ValueError("Cannot transpose with respect to sparse indices.")
if ad.is_undefined_primal(events):
ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose = transpose)[0]
ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0]
return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events)
else:
if type(ct[0]) is ad.Zero:
ct_values = ad.Zero(values)
else:
if values.aval.shape[0] == 1: # scalar
ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose =transpose)[0]
ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
ct_values = jnp.inner(ct[0], ct_values)
else: # heterogeneous values
row, col = csr_to_coo(indices, indptr)
Expand Down Expand Up @@ -345,8 +352,6 @@ def csrmv_taichi(
if indices.shape[0] == 0:
return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

bool_param_list = jnp.array([transpose, events.dtype == jnp.bool_, data.shape[0] > 1])

prim = None

if transpose:
Expand Down
51 changes: 42 additions & 9 deletions brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,48 @@ def func(*args, **kwargs):
return r.sum()
return func

# transposes = [True, False]
# shapes = [(100, 200),
# (200, 200),
# (200, 100),
# (10, 1000),
# (2, 10000),
# (1000, 10),
# (10000, 2)]
# homo_datas = [-1., 0., 1.]
# Parameter grids for the CSR matvec comparison tests in this module.

# Whether the product is taken with the transposed matrix.
transposes = [True, False]

# (rows, cols) matrix shapes: wide, square and tall cases, including very
# skewed aspect ratios.
shapes = [
    (100, 200),
    (200, 200),
    (200, 100),
    (10, 1000),
    (2, 10000),
    (1000, 10),
    (10000, 2),
]

# Scalar weights (negative, zero, positive); presumably fed to the
# `homo_data` argument of `test_homo`.
homo_datas = [-1.0, 0.0, 1.0]

def test_homo(shape, transpose, homo_data):
    """Check that ``csrmv_taichi`` matches ``csrmv`` for a single scalar weight.

    A ``FixedProb(0.4)`` connectivity is generated for ``shape`` and requested
    as ``'pre2post'`` index arrays, a boolean event vector is drawn at random,
    and the two event-driven CSR matvec implementations are compared.

    Args:
        shape: Two-element shape of the sparse matrix; also forwarded as the
            ``shape`` keyword of both implementations.
        transpose: Forwarded as the ``transpose`` keyword; also selects which
            entry of ``shape`` gives the length of the event vector.
        homo_data: Scalar value shared by every stored matrix entry.

    Raises:
        AssertionError: If the two results are not close.
    """
    print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
    rng = bm.random.RandomState()
    indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
    # The event vector is as long as the side of the matrix it multiplies;
    # thresholding uniform samples at 0.1 yields a sparse boolean vector.
    events = rng.random(shape[0] if transpose else shape[1]) < 0.1

    r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
    r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose)

    # The taichi result is indexed with [0]; presumably the custom op returns
    # a sequence of outputs -- confirm against csrmv_taichi.
    assert bm.allclose(r1, r2[0])

    bm.clear_buffer_memory()

def test_heter(shape, transpose):
    """Check that ``csrmv_taichi`` matches ``csrmv`` for per-entry weights.

    Builds a ``FixedProb(0.4)`` connectivity for ``shape``, converts its
    ``'pre2post'`` index arrays with ``bm.as_jax``, draws a boolean event
    vector and one random weight per stored entry, then compares the two
    implementations.

    Args:
        shape: Two-element shape of the sparse matrix; also forwarded as the
            ``shape`` keyword of both implementations.
        transpose: Forwarded as the ``transpose`` keyword; also selects which
            entry of ``shape`` gives the length of the event vector.

    Raises:
        AssertionError: If the two results are not close.
    """
    print(f'test_heter: shape = {shape}, transpose = {transpose}')
    rng = bm.random.RandomState()

    connectivity = bp.conn.FixedProb(0.4)(*shape)
    idx, ptr = (bm.as_jax(arr) for arr in connectivity.require('pre2post'))

    # Events are drawn before the weights so the random stream is consumed in
    # a fixed order.
    n_events = shape[0] if transpose else shape[1]
    spikes = bm.as_jax(rng.random(n_events)) < 0.1
    weights = bm.as_jax(rng.random(idx.shape))

    common = dict(shape=shape, transpose=transpose)
    expected = bm.event.csrmv(weights, idx, ptr, spikes, **common)
    actual = bm.event.csrmv_taichi(weights, idx, ptr, spikes, **common)

    # The taichi result is indexed with [0]; presumably the custom op returns
    # a sequence of outputs -- confirm against csrmv_taichi.
    assert (bm.allclose(expected, actual[0]))

    bm.clear_buffer_memory()


class Test_event_csr_matvec(parameterized.TestCase):
def __init__(self, *args, platform='cpu', **kwargs):
Expand Down
42 changes: 20 additions & 22 deletions brainpy/_src/math/sparse/_csr_mv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
]

# CPU kernel for the transposed CSR matrix-vector product.
#
# For every stored entry j of row `row_i`, the contribution
# `weight * vector[row_i]` is added to `out[col_indices[j]]`.
# `values` holds either one weight shared by all entries (length 1) or one
# weight per stored entry.
#
# `out` is only accumulated into (`+=`) and never cleared here, so this
# assumes the caller supplies a zero-initialised buffer -- TODO confirm.
@ti.kernel
def _sparse_csr_matvec_transpose_cpu(values: ti.types.ndarray(ndim=1),
                                     col_indices: ti.types.ndarray(ndim=1),
                                     row_ptr: ti.types.ndarray(ndim=1),
                                     vector: ti.types.ndarray(ndim=1),
                                     out: ti.types.ndarray(ndim=1)):
    if values.shape[0] == 1:
        # Homogeneous case: read the single shared weight once.
        value = values[0]
        ti.loop_config(serialize=True)
        # The row count comes from the row-pointer array (one more entry than
        # there are rows) because row_ptr[row_i + 1] is read below.
        for row_i in range(row_ptr.shape[0] - 1):
            for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
                out[col_indices[j]] += value * vector[row_i]

    else:
        # Heterogeneous case: one weight per stored entry.
        ti.loop_config(serialize=True)
        for row_i in range(row_ptr.shape[0] - 1):
            for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
                out[col_indices[j]] += values[j] * vector[row_i]

Expand All @@ -52,35 +52,35 @@ def _sparse_csr_matvec_cpu(values: ti.types.ndarray(ndim=1),
if values.shape[0] == 1:
value = values[0]
ti.loop_config(serialize=True)
for row_i in range(vector.shape[0]):
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += value * vector[col_indices[j]]
out[row_i] = r

else:
ti.loop_config(serialize=True)
for row_i in range(vector.shape[0]):
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += values[j] * vector[col_indices[j]]
out[row_i] = r


# GPU kernel for the transposed CSR matrix-vector product.
#
# For every stored entry j of row `row_i`, the contribution
# `weight * vector[row_i]` is added to `out[col_indices[j]]`.
# `values` holds either one weight shared by all entries (length 1) or one
# weight per stored entry.
#
# `out` is only accumulated into (`+=`) and never cleared here, so this
# assumes the caller supplies a zero-initialised buffer -- TODO confirm.
#
# NOTE(review): both row loops sit inside an `if`/`else` branch rather than at
# the kernel's outermost scope; Taichi only auto-parallelises outermost-scope
# loops, so confirm these run in parallel as intended.
@ti.kernel
def _sparse_csr_matvec_transpose_gpu(values: ti.types.ndarray(ndim=1),
                                     col_indices: ti.types.ndarray(ndim=1),
                                     row_ptr: ti.types.ndarray(ndim=1),
                                     vector: ti.types.ndarray(ndim=1),
                                     out: ti.types.ndarray(ndim=1)):
    if values.shape[0] == 1:
        # Homogeneous case: read the single shared weight once.
        value = values[0]
        # The row count comes from the row-pointer array (one more entry than
        # there are rows) because row_ptr[row_i + 1] is read below.
        for row_i in range(row_ptr.shape[0] - 1):
            for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
                out[col_indices[j]] += value * vector[row_i]

    else:
        # Heterogeneous case: one weight per stored entry.
        for row_i in range(row_ptr.shape[0] - 1):
            for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
                out[col_indices[j]] += values[j] * vector[row_i]

Expand All @@ -92,14 +92,14 @@ def _sparse_csr_matvec_gpu(values: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
if values.shape[0] == 1:
value = values[0]
for row_i in range(vector.shape[0]):
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += value * vector[col_indices[j]]
out[row_i] = r

else:
for row_i in range(vector.shape[0]):
for row_i in range(row_ptr.shape[0] - 1):
r = 0.
for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
r += values[j] * vector[col_indices[j]]
Expand Down Expand Up @@ -132,12 +132,11 @@ def _sparse_csr_matvec_jvp(
transpose=transpose)
elif type(vector_dot) is ad.Zero:
dr = csrmv_taichi(values_dot,
col_indices_dot,
row_ptr_dot,
col_indices,
row_ptr,
vector,
shape=shape,
transpose=transpose)

return r, dr

def _sparse_csr_matvec_transpose(
Expand All @@ -152,17 +151,17 @@ def _sparse_csr_matvec_transpose(
ct[0],
shape=shape,
transpose=not transpose)[0]
return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector)

else:
if type(ct) is ad.Zero:
ct_data = ad.Zero
if type(ct[0]) is ad.Zero:
ct_data = ad.Zero(data)
else:
if data.aval.shape[0] == 1: # scalar
ct_data = csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
ct_data = jnp.inner(ct[0], ct_data)
else:
row, col =csr_to_coo(indices, indptr)
row, col = csr_to_coo(indices, indptr)
ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row]

return ct_data, indices, indptr, vector
Expand All @@ -175,7 +174,7 @@ def csrmv_taichi(
*,
shape: Tuple[int, int],
transpose: bool = False,
):
) -> jax.Array:
"""Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm.
This function supports JAX transformations, including `jit()`, `grad()`,
Expand Down Expand Up @@ -239,12 +238,11 @@ def csrmv_taichi(
vector,
outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
transpose = transpose,
shape=shape
)
shape=shape)

# transpose
# Custom XLA op for the transposed product: bundles the CPU and GPU kernels
# and attaches the JVP and transpose rules.
_event_csr_matvec_transpose_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_transpose_cpu,
                                            gpu_kernel=_sparse_csr_matvec_transpose_gpu)
_event_csr_matvec_transpose_p.def_jvp_rule(_sparse_csr_matvec_jvp)
_event_csr_matvec_transpose_p.def_transpose_rule(_sparse_csr_matvec_transpose)

Expand Down
Loading

0 comments on commit 05cf134

Please sign in to comment.