
[math] Remove the multiple-result return from event csrmv and sparse csrmv
Routhleck committed Jan 22, 2024
1 parent b24a544 commit 65305b2
Showing 4 changed files with 102 additions and 311 deletions.
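In short: the primitive-dispatch logic of csrmv_taichi is extracted into a new internal helper raw_csrmv_taichi (in both the event and the sparse modules). The helper returns the primitive's raw output list; the public csrmv_taichi wrappers now index [0] and return a plain vector instead of a one-element list. A minimal sketch of the resulting call pattern (hand-built 2x3 CSR matrix with hypothetical values; assumes a working taichi backend):

import jax.numpy as jnp
import brainpy.math as bm

# a 2x3 CSR matrix: [[1, 0, 2],
#                    [0, 3, 0]]
indptr = jnp.array([0, 2, 3])            # row pointers
indices = jnp.array([0, 2, 1])           # column indices of the nonzeros
data = jnp.array([1., 2., 3.])           # heterogeneous nonzero values
events = jnp.array([True, False, True])  # one flag per column (transpose=False)

y = bm.event.csrmv_taichi(data, indices, indptr, events, shape=(2, 3))
# before this commit: y was a one-element list, so callers wrote y[0]
# after this commit:  y is the length-2 result vector itself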
76 changes: 43 additions & 33 deletions brainpy/_src/math/event/_csr_matvec_taichi.py
@@ -10,7 +10,7 @@
from brainpy._src.dependency_check import import_taichi
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.op_register import XLACustomOp
-from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi
+from brainpy._src.math.sparse._csr_mv_taichi import raw_csrmv_taichi as normal_csrmv_taichi
from brainpy._src.math.sparse._utils import csr_to_coo

ti = import_taichi()
@@ -333,13 +333,53 @@ def _event_csr_matvec_transpose(
      ct_values = ad.Zero(values)
    else:
      if values.aval.shape[0] == 1:  # scalar
-        ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
+        ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
        ct_values = jnp.inner(ct[0], ct_values)
      else:  # heterogeneous values
        row, col = csr_to_coo(indices, indptr)
        ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
    return ct_values, indices, indptr, events
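For reference, the heterogeneous branch above (events[row] * ct[0][col] when transpose=True) agrees with what jax.grad produces on a dense stand-in. A small self-contained check with hypothetical 2x3 data, using a dense matmul instead of the taichi primitive:

import jax
import jax.numpy as jnp

indptr = jnp.array([0, 2, 3])
indices = jnp.array([0, 2, 1])  # nonzeros at (0, 0), (0, 2), (1, 1)
rows = jnp.array([0, 0, 1])     # row of each nonzero, as csr_to_coo would give
vals = jnp.array([1., 2., 3.])
events = jnp.array([1., 0.])    # float events, one per row
ct = jnp.array([1., 1., 1.])    # cotangent of the length-3 output

def dense_transpose_mv(v):
  # y[c] = sum_r A[r, c] * events[r]; dense stand-in for the event primitive
  mat = jnp.zeros((2, 3)).at[rows, indices].set(v)
  return mat.T @ events

grad_vals = jax.grad(lambda v: dense_transpose_mv(v) @ ct)(vals)
assert jnp.allclose(grad_vals, events[rows] * ct[indices])  # events[row] * ct[col]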

+def raw_csrmv_taichi(
+    data: Union[float, jax.Array],
+    indices: jax.Array,
+    indptr: jax.Array,
+    events: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False
+):
+  if transpose:
+    if events.dtype == jnp.bool_:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_transpose_bool_homo_p
+      else:
+        prim = _event_csrmv_transpose_bool_heter_p
+    else:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_transpose_homo_p
+      else:
+        prim = _event_csrmv_transpose_heter_p
+  else:
+    if events.dtype == jnp.bool_:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_bool_homo_p
+      else:
+        prim = _event_csrmv_bool_heter_p
+    else:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_homo_p
+      else:
+        prim = _event_csrmv_heter_p
+
+  # computing
+  return prim(data,
+              indices,
+              indptr,
+              events,
+              outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
+              transpose=transpose,
+              shape=shape)
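A hedged usage sketch of the new helper (private import path as in this file; assumes a working taichi backend). Boolean events plus a length-1 data array select the _event_csrmv_bool_homo_p primitive, and the primitive's outputs come back as a list:

import jax.numpy as jnp
from brainpy._src.math.event._csr_matvec_taichi import raw_csrmv_taichi

indptr = jnp.array([0, 2, 3])
indices = jnp.array([0, 2, 1])
data = jnp.ones(1) * 0.5                 # shape (1,) -> homogeneous dispatch
events = jnp.array([True, False, True])  # bool dtype -> *_bool_* dispatch

outs = raw_csrmv_taichi(data, indices, indptr, events, shape=(2, 3))
y = outs[0]  # callers such as csrmv_taichi and the AD rules take outs[0]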

def csrmv_taichi(
    data: Union[float, jax.Array],
@@ -419,37 +459,7 @@ def csrmv_taichi(
  if indices.shape[0] == 0:
    return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)

-  if transpose:
-    if events.dtype == jnp.bool_:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_transpose_bool_homo_p
-      else:
-        prim = _event_csrmv_transpose_bool_heter_p
-    else:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_transpose_homo_p
-      else:
-        prim = _event_csrmv_transpose_heter_p
-  else:
-    if events.dtype == jnp.bool_:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_bool_homo_p
-      else:
-        prim = _event_csrmv_bool_heter_p
-    else:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_homo_p
-      else:
-        prim = _event_csrmv_heter_p
-
-  # computing
-  return prim(data,
-              indices,
-              indptr,
-              events,
-              outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
-              transpose=transpose,
-              shape=shape)
+  return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]


def _define_op(cpu_kernel, gpu_kernel):
34 changes: 13 additions & 21 deletions brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
@@ -20,14 +20,6 @@ def func(*args, **kwargs):
  return func


-def sum_op2(op):
-  def func(*args, **kwargs):
-    r = op(*args, **kwargs)[0]
-    return r.sum()
-
-  return func
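sum_op2 existed only to strip the old one-element list (the [0] above) before summing for jax.grad; since csrmv_taichi now returns the vector directly, the remaining sum_op helper serves both code paths. Its full definition is not shown in this hunk; presumably it is the same reduction without the indexing, along these lines:

def sum_op(op):
  def func(*args, **kwargs):
    r = op(*args, **kwargs)
    return r.sum()

  return func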


class Test_event_csr_matvec_taichi(parameterized.TestCase):
  def __init__(self, *args, platform='cpu', **kwargs):
    super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs)
@@ -53,7 +45,7 @@ def test_homo(self, transpose, shape, homo_data):
    r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
    r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose)

-    assert (bm.allclose(r1, r2[0]))
+    assert (bm.allclose(r1, r2))

    bm.clear_buffer_memory()

@@ -78,15 +70,15 @@ def test_homo_vmap(self, shape, transpose, homo_data):
    f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
                          shape=shape, transpose=transpose))
    vmap_data = bm.as_jax([homo_data] * 10)
-    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))

    # vmap 'events'
    f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr,
                          shape=shape, transpose=transpose))
    f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr,
                          shape=shape, transpose=transpose))
    vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
-    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))

    # vmap 'data' and 'events'
    f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
@@ -95,7 +87,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
    vmap_data1 = bm.as_jax([homo_data] * 10)
    vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
    self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
-                                f6(vmap_data1, vmap_data2)[0]))
+                                f6(vmap_data1, vmap_data2)))

    bm.clear_buffer_memory()

@@ -120,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data):
    # grad 'data'
    r1 = jax.grad(sum_op(bm.event.csrmv))(
      homo_data, indices, indptr, events, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.event.csrmv_taichi))(
      homo_data, indices, indptr, events, shape=shape, transpose=transpose)
    self.assertTrue(bm.allclose(r1, r2))

    # grad 'events'
    r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
      homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)(
      homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
    self.assertTrue(bm.allclose(r3, r4))

@@ -154,7 +146,7 @@ def test_heter(self, shape, transpose):
    r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events,
                               shape=shape, transpose=transpose)

-    assert (bm.allclose(r1, r2[0]))
+    assert (bm.allclose(r1, r2))

    bm.clear_buffer_memory()

@@ -180,7 +172,7 @@ def test_heter_vmap(self, shape, transpose):
    f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
                          shape=shape, transpose=transpose))
    vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
-    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))

    # vmap 'events'
    data = bm.as_jax(rng.random(indices.shape))
@@ -189,7 +181,7 @@ def test_heter_vmap(self, shape, transpose):
    f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr,
                          shape=shape, transpose=transpose))
    vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
-    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))

    # vmap 'data' and 'events'
    f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee,
@@ -199,7 +191,7 @@ def test_heter_vmap(self, shape, transpose):
    vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
    vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
    self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
-                                f6(vmap_data1, vmap_data2)[0]))
+                                f6(vmap_data1, vmap_data2)))

    bm.clear_buffer_memory()

@@ -225,20 +217,20 @@ def test_heter_grad(self, shape, transpose):
    data = bm.as_jax(rng.random(indices.shape))
    r1 = jax.grad(sum_op(bm.event.csrmv))(
      data, indices, indptr, events, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.event.csrmv_taichi))(
      data, indices, indptr, events, shape=shape, transpose=transpose)
    self.assertTrue(bm.allclose(r1, r2))

    # grad 'events'
    r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
      data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)(
      data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
    self.assertTrue(bm.allclose(r3, r4))

    r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))(
      data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))(
+    r6 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=(0, 3))(
      data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
    self.assertTrue(bm.allclose(r5[0], r6[0]))
    self.assertTrue(bm.allclose(r5[1], r6[1]))
60 changes: 36 additions & 24 deletions brainpy/_src/math/sparse/_csr_mv_taichi.py
@@ -155,11 +155,11 @@ def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),


def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
-  return csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)
+  return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)


def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
-  return csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)
+  return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)


def _sparse_csr_matvec_transpose(
@@ -168,22 +168,51 @@ def _sparse_csr_matvec_transpose(
  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
    raise ValueError("Cannot transpose with respect to sparse indices.")
  if ad.is_undefined_primal(vector):
-    ct_vector = csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
+    ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
    return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector)

  else:
    if type(ct[0]) is ad.Zero:
      ct_data = ad.Zero(data)
    else:
      if data.aval.shape[0] == 1:  # scalar
-        ct_data = csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
+        ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
        ct_data = jnp.inner(ct[0], ct_data)
      else:
        row, col = csr_to_coo(indices, indptr)
        ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row]

    return ct_data, indices, indptr, vector

+def raw_csrmv_taichi(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    vector: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  out_shape = shape[1] if transpose else shape[0]
+  if transpose:
+    if data.shape[0] == 1:
+      prim = _csr_matvec_transpose_homo_p
+    else:
+      prim = _csr_matvec_transpose_heter_p
+  else:
+    if data.shape[0] == 1:
+      prim = _csr_matvec_homo_p
+    else:
+      prim = _csr_matvec_heter_p
+
+  return prim(data,
+              indices,
+              indptr,
+              vector,
+              outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
+              transpose=transpose,
+              shape=shape)
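As an aside, a pure-JAX reference of what these four primitives compute (an eager-mode educational sketch, not the taichi kernels; jnp.repeat recovers the per-nonzero row index the way csr_to_coo does):

import jax.numpy as jnp

def csrmv_reference(data, indices, indptr, vector, *, shape, transpose=False):
  nnz = indices.shape[0]
  # expand a homogeneous length-1 data array to one value per nonzero
  vals = jnp.broadcast_to(data, (nnz,)) if data.shape[0] == 1 else data
  rows = jnp.repeat(jnp.arange(shape[0]), jnp.diff(indptr))  # row of each nonzero
  if transpose:
    # y[j] = sum_i A[i, j] * vector[i]
    return jnp.zeros(shape[1], vals.dtype).at[indices].add(vals * vector[rows])
  # y[i] = sum_j A[i, j] * vector[j]
  return jnp.zeros(shape[0], vals.dtype).at[rows].add(vals * vector[indices])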


def csrmv_taichi(
    data: Union[float, jnp.ndarray, Array],
@@ -242,26 +271,9 @@ def csrmv_taichi(
    raise ValueError('indices should be a 1D vector with integer type.')
  if not jnp.issubdtype(indptr.dtype, jnp.integer):
    raise ValueError('indptr should be a 1D vector with integer type.')
-  out_shape = shape[1] if transpose else shape[0]
-
-  if transpose:
-    if data.shape[0] == 1:
-      prim = _csr_matvec_transpose_homo_p
-    else:
-      prim = _csr_matvec_transpose_heter_p
-  else:
-    if data.shape[0] == 1:
-      prim = _csr_matvec_homo_p
-    else:
-      prim = _csr_matvec_heter_p
-
-  return prim(data,
-              indices,
-              indptr,
-              vector,
-              outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
-              transpose=transpose,
-              shape=shape)

+  return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]



def _define_op(cpu_kernel, gpu_kernel):
