diff --git a/brainpy/_src/math/event/_csr_matvec_taichi.py b/brainpy/_src/math/event/_csr_matvec_taichi.py
index 9be9c49d..2ee47d83 100644
--- a/brainpy/_src/math/event/_csr_matvec_taichi.py
+++ b/brainpy/_src/math/event/_csr_matvec_taichi.py
@@ -10,7 +10,7 @@
 from brainpy._src.dependency_check import import_taichi
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.op_register import XLACustomOp
-from brainpy._src.math.sparse._csr_mv_taichi import csrmv_taichi as normal_csrmv_taichi
+from brainpy._src.math.sparse._csr_mv_taichi import raw_csrmv_taichi as normal_csrmv_taichi
 from brainpy._src.math.sparse._utils import csr_to_coo
 
 ti = import_taichi()
@@ -333,13 +333,53 @@ def _event_csr_matvec_transpose(
       ct_values = ad.Zero(values)
     else:
       if values.aval.shape[0] == 1:  # scalar
-        ct_values = csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
+        ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0]
         ct_values = jnp.inner(ct[0], ct_values)
       else:  # heterogeneous values
         row, col = csr_to_coo(indices, indptr)
         ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row]
     return ct_values, indices, indptr, events
 
 
+def raw_csrmv_taichi(
+    data: Union[float, jax.Array],
+    indices: jax.Array,
+    indptr: jax.Array,
+    events: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False
+):
+  if transpose:
+    if events.dtype == jnp.bool_:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_transpose_bool_homo_p
+      else:
+        prim = _event_csrmv_transpose_bool_heter_p
+    else:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_transpose_homo_p
+      else:
+        prim = _event_csrmv_transpose_heter_p
+  else:
+    if events.dtype == jnp.bool_:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_bool_homo_p
+      else:
+        prim = _event_csrmv_bool_heter_p
+    else:
+      if data.shape[0] == 1:
+        prim = _event_csrmv_homo_p
+      else:
+        prim = _event_csrmv_heter_p
+
+  # computing
+  return prim(data,
+              indices,
+              indptr,
+              events,
+              outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
+              transpose=transpose,
+              shape=shape)
+
+
 def csrmv_taichi(
     data: Union[float, jax.Array],
@@ -419,37 +459,7 @@
   if indices.shape[0] == 0:
     return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype)
 
-  if transpose:
-    if events.dtype == jnp.bool_:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_transpose_bool_homo_p
-      else:
-        prim = _event_csrmv_transpose_bool_heter_p
-    else:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_transpose_homo_p
-      else:
-        prim = _event_csrmv_transpose_heter_p
-  else:
-    if events.dtype == jnp.bool_:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_bool_homo_p
-      else:
-        prim = _event_csrmv_bool_heter_p
-    else:
-      if data.shape[0] == 1:
-        prim = _event_csrmv_homo_p
-      else:
-        prim = _event_csrmv_heter_p
-
-  # computing
-  return prim(data,
-              indices,
-              indptr,
-              events,
-              outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)],
-              transpose=transpose,
-              shape=shape)
+  return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0]
 
 
 def _define_op(cpu_kernel, gpu_kernel):
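With this split, `raw_csrmv_taichi` keeps the primitive's native output (a single-element list of arrays) for internal use by the autodiff rules, while the public `csrmv_taichi` unwraps it with `[0]` and returns a plain array. A minimal sketch of the new calling convention; the toy CSR data below is hypothetical and assumes brainpy and taichi are installed:

```python
import jax.numpy as jnp
import brainpy.math as bm

# Hypothetical 2x3 CSR matrix: row 0 holds columns 0 and 2, row 1 holds column 1.
data = jnp.array([1.5])                  # shape (1,) selects the homogeneous-weight kernels
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
events = jnp.array([True, False, True])  # bool dtype selects the `bool` kernels

# The public API now returns a plain array instead of a single-element list.
out = bm.event.csrmv_taichi(data, indices, indptr, events, shape=(2, 3))
# Under standard CSR matvec semantics, row 0 sees two active columns (0 and 2)
# and row 1 sees none, so `out` should be [3.0, 0.0].
print(out)
```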
diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
index b759a478..c81aee7c 100644
--- a/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
@@ -20,14 +20,6 @@ def func(*args, **kwargs):
   return func
 
 
-def sum_op2(op):
-  def func(*args, **kwargs):
-    r = op(*args, **kwargs)[0]
-    return r.sum()
-
-  return func
-
-
 class Test_event_csr_matvec_taichi(parameterized.TestCase):
   def __init__(self, *args, platform='cpu', **kwargs):
     super(Test_event_csr_matvec_taichi, self).__init__(*args, **kwargs)
@@ -53,7 +45,7 @@ def test_homo(self, transpose, shape, homo_data):
 
     r1 = bm.event.csrmv(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
     r2 = bm.event.csrmv_taichi(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
-    assert (bm.allclose(r1, r2[0]))
+    assert (bm.allclose(r1, r2))
 
     bm.clear_buffer_memory()
@@ -78,7 +70,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
     f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
     vmap_data = bm.as_jax([homo_data] * 10)
-    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
 
     # vmap 'events'
     f3 = jax.vmap(partial(bm.event.csrmv, homo_data, indices, indptr,
@@ -86,7 +78,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
     f4 = jax.vmap(partial(bm.event.csrmv_taichi, homo_data, indices, indptr,
                           shape=shape, transpose=transpose))
     vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
-    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
 
     # vmap 'data' and 'events'
     f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
@@ -95,7 +87,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
     vmap_data1 = bm.as_jax([homo_data] * 10)
     vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
     self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
-                                f6(vmap_data1, vmap_data2)[0]))
+                                f6(vmap_data1, vmap_data2)))
 
     bm.clear_buffer_memory()
@@ -120,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data):
     # grad 'data'
     r1 = jax.grad(sum_op(bm.event.csrmv))(
       homo_data, indices, indptr, events, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.event.csrmv_taichi))(
       homo_data, indices, indptr, events, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'events'
     r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
       homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)(
       homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
@@ -154,7 +146,7 @@ def test_heter(self, shape, transpose):
     r2 = bm.event.csrmv_taichi(heter_data, indices, indptr, events,
                                shape=shape, transpose=transpose)
 
-    assert (bm.allclose(r1, r2[0]))
+    assert (bm.allclose(r1, r2))
 
     bm.clear_buffer_memory()
@@ -180,7 +172,7 @@ def test_heter_vmap(self, shape, transpose):
     f2 = jax.vmap(partial(bm.event.csrmv_taichi, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
     vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
-    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
 
     # vmap 'events'
     data = bm.as_jax(rng.random(indices.shape))
@@ -189,7 +181,7 @@ def test_heter_vmap(self, shape, transpose):
     f4 = jax.vmap(partial(bm.event.csrmv_taichi, data, indices, indptr,
                           shape=shape, transpose=transpose))
     vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
-    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)[0]))
+    self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
 
     # vmap 'data' and 'events'
     f5 = jax.vmap(lambda dd, ee: bm.event.csrmv(dd, indices, indptr, ee,
@@ -199,7 +191,7 @@ def test_heter_vmap(self, shape, transpose):
     vmap_data1 = bm.as_jax(rng.random((10, indices.shape[0])))
     vmap_data2 = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.2
     self.assertTrue(bm.allclose(f5(vmap_data1, vmap_data2),
-                                f6(vmap_data1, vmap_data2)[0]))
+                                f6(vmap_data1, vmap_data2)))
 
     bm.clear_buffer_memory()
@@ -225,20 +217,20 @@ def test_heter_grad(self, shape, transpose):
     data = bm.as_jax(rng.random(indices.shape))
     r1 = jax.grad(sum_op(bm.event.csrmv))(
       data, indices, indptr, events, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.event.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.event.csrmv_taichi))(
       data, indices, indptr, events, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'events'
     r3 = jax.grad(sum_op(bm.event.csrmv), argnums=3)(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=3)(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
     r5 = jax.grad(sum_op(bm.event.csrmv), argnums=(0, 3))(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
-    r6 = jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=(0, 3))(
+    r6 = jax.grad(sum_op(bm.event.csrmv_taichi), argnums=(0, 3))(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r5[0], r6[0]))
     self.assertTrue(bm.allclose(r5[1], r6[1]))
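Because both `bm.event.csrmv` and `bm.event.csrmv_taichi` now return an array, the tests no longer need the `sum_op2` helper that stripped the `[0]` wrapper; a single reduction helper serves both implementations. A sketch of the pattern the updated gradient tests rely on, using hypothetical toy inputs:

```python
import jax
import jax.numpy as jnp
import brainpy.math as bm

def sum_op(op):
  # Same helper as in the tests: reduce the output so jax.grad gets a scalar.
  def func(*args, **kwargs):
    return op(*args, **kwargs).sum()
  return func

data = jnp.array([0.5, 0.5, 0.5])   # heterogeneous weights, one per nonzero
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
events = jnp.array([1., 0., 1.])    # float events, as in the grad tests

# One helper now serves both implementations; no [0] indexing needed.
grads = jax.grad(sum_op(bm.event.csrmv_taichi))(
  data, indices, indptr, events, shape=(2, 3), transpose=False)
print(grads)  # d(sum)/d(data), one entry per nonzero
```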
diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py
index cd09af08..5038e372 100644
--- a/brainpy/_src/math/sparse/_csr_mv_taichi.py
+++ b/brainpy/_src/math/sparse/_csr_mv_taichi.py
@@ -155,11 +155,11 @@ def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1),
 
 
 def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
-  return csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)
+  return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose)
 
 
 def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape):
-  return csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)
+  return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose)
 
 
 def _sparse_csr_matvec_transpose(
@@ -168,7 +168,7 @@
   if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
     raise ValueError("Cannot transpose with respect to sparse indices.")
   if ad.is_undefined_primal(vector):
-    ct_vector = csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
+    ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0]
     return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector)
 
   else:
@@ -176,7 +176,7 @@ def _sparse_csr_matvec_transpose(
       ct_data = ad.Zero(data)
     else:
       if data.aval.shape[0] == 1:  # scalar
-        ct_data = csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
+        ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0]
         ct_data = jnp.inner(ct[0], ct_data)
       else:
         row, col = csr_to_coo(indices, indptr)
@@ -184,6 +184,35 @@
   return ct_data, indices, indptr, vector
 
 
+def raw_csrmv_taichi(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    vector: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  out_shape = shape[1] if transpose else shape[0]
+  if transpose:
+    if data.shape[0] == 1:
+      prim = _csr_matvec_transpose_homo_p
+    else:
+      prim = _csr_matvec_transpose_heter_p
+  else:
+    if data.shape[0] == 1:
+      prim = _csr_matvec_homo_p
+    else:
+      prim = _csr_matvec_heter_p
+
+  return prim(data,
+              indices,
+              indptr,
+              vector,
+              outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
+              transpose=transpose,
+              shape=shape)
+
 
 def csrmv_taichi(
     data: Union[float, jnp.ndarray, Array],
@@ -242,26 +271,9 @@ def csrmv_taichi(
     raise ValueError('indices should be a 1D vector with integer type.')
   if not jnp.issubdtype(indptr.dtype, jnp.integer):
     raise ValueError('indptr should be a 1D vector with integer type.')
-  out_shape = shape[1] if transpose else shape[0]
-
-  if transpose:
-    if data.shape[0] == 1:
-      prim = _csr_matvec_transpose_homo_p
-    else:
-      prim = _csr_matvec_transpose_heter_p
-  else:
-    if data.shape[0] == 1:
-      prim = _csr_matvec_homo_p
-    else:
-      prim = _csr_matvec_heter_p
-
-  return prim(data,
-              indices,
-              indptr,
-              vector,
-              outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)],
-              transpose=transpose,
-              shape=shape)
+
+  return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0]
+
 
 
 def _define_op(cpu_kernel, gpu_kernel):
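The autodiff rules above keep calling the raw primitive: the cotangent with respect to the dense vector is the same CSR matvec with `transpose` flipped, i.e. multiplication by the transposed matrix. A dense sanity check of that identity (plain JAX, not library code):

```python
import jax.numpy as jnp

# For y = M @ v, pulling a cotangent `ct` back to v gives M.T @ ct, which in
# CSR terms is the same matvec with `transpose` flipped -- exactly what
# raw_csrmv_taichi(..., transpose=not transpose) computes in the rule above.
M = jnp.array([[1.5, 0.0, 1.5],
               [0.0, 0.5, 0.0]])
ct = jnp.array([1.0, 1.0])
print(M.T @ ct)  # expected: [1.5, 0.5, 1.5]
```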
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py
index 2b3d7b5b..fed665c8 100644
--- a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py
+++ b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py
@@ -21,13 +21,6 @@ def func(*args, **kwargs):
   return func
 
 
-def sum_op2(op):
-  def func(*args, **kwargs):
-    r = op(*args, **kwargs)[0]
-    return r.sum()
-
-  return func
-
 
 def compare_with_nan_tolerance(a, b, tol=1e-8):
   """
@@ -62,219 +55,6 @@
 
 vector_csr_matvec = partial(bm.sparse.csrmv, method='vector')
 
-### MANUAL TESTS ###
-# transposes = [True, False]
-# homo_datas = [-1., 0., 0.1, 1.]
-# shapes = [(100, 200), (10, 1000), (2, 2000)]
-#
-#
-# def test_homo(transpose, shape, homo_data):
-#   print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   # matrix
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   # vector
-#   rng = bm.random.RandomState(123)
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#   r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#   assert (bm.allclose(r1, r2[0]))
-#
-#   bm.clear_buffer_memory()
-#
-#
-# def test_homo_vmap(transpose, shape, homo_data):
-#   print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
-#   rng = bm.random.RandomState()
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   heter_data = bm.ones((10, indices.shape[0])).value * homo_data
-#   homo_data = bm.ones(10).value * homo_data
-#   dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
-#
-#   f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector,
-#                shape=shape, transpose=transpose)
-#   f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector,
-#                shape=shape, transpose=transpose)
-#   r1 = jax.vmap(f1)(homo_data)
-#   r2 = jax.vmap(f1)(homo_data)
-#   assert (bm.allclose(r1, r2[0]))
-#
-#   bm.clear_buffer_memory()
-#
-#
-# def test_homo_grad(transpose, shape, homo_data):
-#   print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
-#   rng = bm.random.RandomState()
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
-#                                  indices,
-#                                  indptr,
-#                                  shape=shape)
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   # print('grad data start')
-#   # grad 'data'
-#   r1 = jax.grad(sum_op(vector_csr_matvec))(
-#     homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#   r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))(
-#     homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#
-#   # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector,
-#   #                                               shape=shape, transpose=transpose).sum(),
-#   #                   argnums=0)
-#   # csr_f2 = jax.grad(lambda a: bm.sparse.csrmv_taichi(a, indices, indptr, vector,
-#   #                                                    shape=shape, transpose=transpose)[0].sum(),
-#   #                   argnums=0)
-#   # r1 = csr_f1(homo_data)
-#   # r2 = csr_f2(homo_data)
-#   assert (bm.allclose(r1, r2))
-#
-#   # print('grad vector start')
-#   # grad 'vector'
-#   r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)(
-#     homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
-#     homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   # csr_f3 = jax.grad(lambda v: vector_csr_matvec(homo_data, indices, indptr, v,
-#   #                                               shape=shape, transpose=transpose).sum())
-#   # csr_f4 = jax.grad(lambda v: bm.sparse.csrmv_taichi(homo_data, indices, indptr, v,
-#   #                                                    shape=shape, transpose=transpose)[0].sum())
-#   # r3 = csr_f3(vector)
-#   # r4 = csr_f4(vector)
-#   assert (bm.allclose(r3, r4))
-#
-#   # csr_f5 = jax.grad(lambda a, v: vector_csr_matvec(a, indices, indptr, v,
-#   #                                                  shape=shape, transpose=transpose).sum(),
-#   #                   argnums=(0, 1))
-#   # csr_f6 = jax.grad(lambda a, v: bm.sparse.csrmv_taichi(a, indices, indptr, v,
-#   #                                                       shape=shape, transpose=transpose)[0].sum(),
-#   #                   argnums=(0, 1))
-#   # r5 = csr_f5(homo_data, vector)
-#   # r6 = csr_f6(homo_data, vector)
-#   # assert(bm.allclose(r5[0], r6[0]))
-#   # assert(bm.allclose(r5[1], r6[1]))
-#
-#   bm.clear_buffer_memory()
-#
-#
-# def test_heter(transpose, shape):
-#   print(f'test_heter: transpose = {transpose} shape = {shape}')
-#   rng = bm.random.RandomState()
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   heter_data = bm.as_jax(rng.random(indices.shape))
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
-#   r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape)
-#   # bm.nan_to_num(r1)
-#   # bm.nan_to_num(r2[0])
-#   # print(r1)
-#   # print(r1 - r2[0])
-#   assert (compare_with_nan_tolerance(r1, r2[0]))
-#
-#   bm.clear_buffer_memory()
-#
-#
-# def test_heter_vmap(transpose, shape):
-#   print(f'test_heter_vmap: transpose = {transpose} shape = {shape}')
-#   rng = bm.random.RandomState()
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   heter_data = rng.random((10, indices.shape[0]))
-#   heter_data = bm.as_jax(heter_data)
-#   dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr,
-#                                                          shape=shape))(heter_data)
-#
-#   f1 = partial(vector_csr_matvec, indices=indices, indptr=indptr, vector=vector,
-#                shape=shape, transpose=transpose)
-#   f2 = partial(bm.sparse.csrmv_taichi, indices=indices, indptr=indptr, vector=vector,
-#                shape=shape, transpose=transpose)
-#   r1 = jax.vmap(f1)(heter_data)
-#   r2 = jax.vmap(f2)(heter_data)
-#   assert (bm.allclose(r1, r2[0]))
-#
-#
-# def test_heter_grad(transpose, shape):
-#   print(f'test_heter_grad: transpose = {transpose} shape = {shape}')
-#   rng = bm.random.RandomState()
-#   conn = bp.conn.FixedProb(0.1)
-#
-#   indices, indptr = conn(*shape).require('pre2post')
-#   indices = bm.as_jax(indices)
-#   indptr = bm.as_jax(indptr)
-#   heter_data = rng.random(indices.shape)
-#   heter_data = bm.as_jax(heter_data)
-#   dense_data = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
-#   vector = rng.random(shape[0] if transpose else shape[1])
-#   vector = bm.as_jax(vector)
-#
-#   # grad 'data'
-#   r1 = jax.grad(sum_op(vector_csr_matvec))(
-#     heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#   r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))(
-#     heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
-#   assert (bm.allclose(r1, r2))
-#
-#   # grad 'vector'
-#   r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)(
-#     heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
-#     heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   assert (bm.allclose(r3, r4))
-#
-#   r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))(
-#     heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))(
-#     heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-#   assert (bm.allclose(r5[0], r6[0]))
-#   assert (bm.allclose(r5[1], r6[1]))
-#
-#   bm.clear_buffer_memory()
-#
-# def test_all():
-#   # for transpose in transposes:
-#   #   for shape in shapes:
-#   #     for homo_data in homo_datas:
-#   #       test_homo(transpose, shape, homo_data)
-#   #       test_homo_vmap(transpose, shape, homo_data)
-#   #       test_homo_grad(transpose, shape, homo_data)
-#
-#   for transpose in transposes:
-#     for shape in shapes:
-#       test_heter(transpose, shape)
-#       test_heter_vmap(transpose, shape)
-#       test_heter_grad(transpose, shape)
-# test_all()
-
-# PYTEST
 class Test_csrmv_taichi(parameterized.TestCase):
   def __init__(self, *args, platform='cpu', **kwargs):
     super(Test_csrmv_taichi, self).__init__(*args, **kwargs)
@@ -302,7 +82,7 @@ def test_homo(self, transpose, shape, homo_data):
 
     r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
     r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-    self.assertTrue(bm.allclose(r1, r2[0]))
+    self.assertTrue(bm.allclose(r1, r2))
 
     bm.clear_buffer_memory()
@@ -332,7 +112,7 @@ def test_homo_vmap(self, transpose, shape, v):
                  shape=shape, transpose=transpose)
     r1 = jax.vmap(f1)(homo_data)
     r2 = jax.vmap(f1)(homo_data)
-    self.assertTrue(bm.allclose(r1, r2[0]))
+    self.assertTrue(bm.allclose(r1, r2))
 
     bm.clear_buffer_memory()
@@ -360,7 +140,7 @@ def test_homo_grad(self, transpose, shape, homo_data):
     # grad 'data'
     r1 = jax.grad(sum_op(vector_csr_matvec))(
       homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))(
       homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
 
     # csr_f1 = jax.grad(lambda a: vector_csr_matvec(a, indices, indptr, vector,
@@ -377,14 +157,14 @@ def test_homo_grad(self, transpose, shape, homo_data):
     # grad 'vector'
     r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
     r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-    r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))(
+    r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r5[0], r6[0]))
     self.assertTrue(bm.allclose(r5[1], r6[1]))
@@ -413,10 +193,7 @@ def test_heter(self, transpose, shape):
 
     r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
     r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape)
-    print(r1)
-    print(r2[0])
-
-    self.assertTrue(compare_with_nan_tolerance(r1, r2[0]))
+    self.assertTrue(compare_with_nan_tolerance(r1, r2))
 
     bm.clear_buffer_memory()
@@ -445,7 +222,7 @@ def test_heter_vmap(self, transpose, shape):
                  shape=shape, transpose=transpose)
     r1 = jax.vmap(f1)(heter_data)
     r2 = jax.vmap(f2)(heter_data)
-    self.assertTrue(compare_with_nan_tolerance(r1, r2[0]))
+    self.assertTrue(compare_with_nan_tolerance(r1, r2))
 
   @parameterized.product(
     transpose=[True, False],
@@ -467,20 +244,20 @@ def test_heter_grad(self, transpose, shape):
     # grad 'data'
     r1 = jax.grad(sum_op(vector_csr_matvec))(
       heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
-    r2 = jax.grad(sum_op2(bm.sparse.csrmv_taichi))(
+    r2 = jax.grad(sum_op(bm.sparse.csrmv_taichi))(
      heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'vector'
     r3 = jax.grad(sum_op(vector_csr_matvec), argnums=3)(
       heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-    r4 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)(
+    r4 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=3)(
       heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
     r5 = jax.grad(sum_op(vector_csr_matvec), argnums=(0, 3))(
       heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-    r6 = jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=(0, 3))(
+    r6 = jax.grad(sum_op(bm.sparse.csrmv_taichi), argnums=(0, 3))(
       heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r5[0], r6[0]))
     self.assertTrue(bm.allclose(r5[1], r6[1]))
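For reference, a hypothetical end-to-end check mirroring the updated assertions: the public sparse API now returns an array that compares directly against the dense result, with no `[0]` unwrapping. `csr_to_dense` is used the same way as in the removed manual tests:

```python
import jax.numpy as jnp
import brainpy.math as bm

# Hypothetical 2x3 CSR matrix and a dense reference computation.
data = jnp.array([1., 2., 3.])
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
vector = jnp.array([1., 1., 1.])

r = bm.sparse.csrmv_taichi(data, indices, indptr, vector, shape=(2, 3))
dense = bm.sparse.csr_to_dense(data, indices, indptr, shape=(2, 3))
assert bm.allclose(r, dense @ vector)
```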