From df8c0bfa2f243321fac2d771da91c7de702b6515 Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Mon, 29 Jan 2024 13:35:14 +0800
Subject: [PATCH] [math] Fix pytest bugs

---
 .../_src/math/event/tests/test_event_csrmv.py | 28 +++++-----
 brainpy/_src/math/sparse/tests/test_csrmv.py  | 53 ++++++++++++-------
 2 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py
index 6f63b045..0598734a 100644
--- a/brainpy/_src/math/event/tests/test_event_csrmv.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmv.py
@@ -8,7 +8,6 @@
 import brainpy as bp
 import brainpy.math as bm
 
-from .._csr_matvec import csrmv_brainpylib as brainpylib_csr_matvec
 
 seed = 1234
 
@@ -44,7 +43,8 @@ def test_homo(self, transpose, shape, homo_data):
     events = rng.random(shape[0] if transpose else shape[1]) < 0.1
 
     heter_data = bm.ones(indices.shape) * homo_data
-    r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
+    dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+    r1 = (events @ dense) if transpose else (dense @ events)
     r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
 
     assert (bm.allclose(r1, r2))
@@ -67,7 +67,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
 
     # vmap 'data'
     events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
-    f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
+    f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
     f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
@@ -75,7 +75,7 @@
     self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))
 
     # vmap 'events'
-    f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr,
+    f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr,
                           shape=shape, transpose=transpose))
     f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr,
                           shape=shape, transpose=transpose))
@@ -83,7 +83,7 @@
     self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
 
     # vmap 'data' and 'events'
-    f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))
+    f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
     f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))
 
     vmap_data1 = bm.as_jax([homo_data] * 10)
@@ -112,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data):
     dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)
 
     # grad 'data'
-    r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
+    r1 = jax.grad(sum_op(bm.sparse.csrmv))(
       homo_data, indices, indptr, events, shape=shape, transpose=transpose)
     r2 = jax.grad(sum_op(taichi_csr_matvec))(
       homo_data, indices, indptr, events, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'events'
-    r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
+    r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
       homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
       homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
@@ -143,7 +143,7 @@ def test_heter(self, shape, transpose):
     events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
 
     heter_data = bm.as_jax(rng.random(indices.shape))
-    r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events,
+    r1 = bm.sparse.csrmv(heter_data, indices, indptr, events,
                                shape=shape, transpose=transpose)
     r2 = taichi_csr_matvec(heter_data, indices, indptr, events,
                            shape=shape, transpose=transpose)
@@ -169,7 +169,7 @@ def test_heter_vmap(self, shape, transpose):
 
     # vmap 'data'
     events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
-    f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
+    f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
     f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
                           shape=shape, transpose=transpose))
@@ -178,7 +178,7 @@
 
     # vmap 'events'
     data = bm.as_jax(rng.random(indices.shape))
-    f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr,
+    f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr,
                           shape=shape, transpose=transpose))
     f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr,
                           shape=shape, transpose=transpose))
@@ -186,7 +186,7 @@
     self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))
 
     # vmap 'data' and 'events'
-    f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee,
+    f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee,
                                                  shape=shape, transpose=transpose))
     f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee,
                                                    shape=shape, transpose=transpose))
@@ -217,20 +217,20 @@ def test_heter_grad(self, shape, transpose):
 
     # grad 'data'
     data = bm.as_jax(rng.random(indices.shape))
-    r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
+    r1 = jax.grad(sum_op(bm.sparse.csrmv))(
       data, indices, indptr, events, shape=shape, transpose=transpose)
     r2 = jax.grad(sum_op(taichi_csr_matvec))(
       data, indices, indptr, events, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'events'
-    r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
+    r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
-    r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
+    r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
     r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
       data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py
index 123ca657..ec3ea3c5 100644
--- a/brainpy/_src/math/sparse/tests/test_csrmv.py
+++ b/brainpy/_src/math/sparse/tests/test_csrmv.py
@@ -7,7 +7,6 @@
 import brainpy as bp
 import brainpy.math as bm
 
-from .._csr_mv import csrmv_brainpylib as brainpylib_csr_matvec
 
 # bm.set_platform('gpu')
 
@@ -80,7 +79,10 @@ def test_homo(self, transpose, shape, homo_data):
     vector = rng.random(shape[0] if transpose else shape[1])
     vector = bm.as_jax(vector)
 
-    r1 = brainpylib_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
+    heter_data = bm.ones(indices.shape).value * homo_data
+
+    dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+    r1 = (vector @ dense) if transpose else (dense @ vector)
     r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
 
     self.assertTrue(bm.allclose(r1, r2))
@@ -106,12 +108,11 @@ def test_homo_vmap(self, transpose, shape, v):
     homo_data = bm.ones(10).value * v
     dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
 
-    f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector,
-                 shape=shape, transpose=transpose)
+    f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
     f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
                  shape=shape, transpose=transpose)
     r1 = jax.vmap(f1)(homo_data)
-    r2 = jax.vmap(f1)(homo_data)
+    r2 = jax.vmap(f2)(homo_data)
     self.assertTrue(bm.allclose(r1, r2))
 
     bm.clear_buffer_memory()
@@ -138,8 +139,11 @@ def test_homo_grad(self, transpose, shape, homo_data):
 
     # print('grad data start')
     # grad 'data'
-    r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
-      homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
+    dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum()
+                                   if transpose else
+                                   ((dense * a) @ vector).sum()),
+                        argnums=0)
+    r1 = dense_f1(homo_data)
     r2 = jax.grad(sum_op(taichi_csr_matvec))(
       homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
 
@@ -155,15 +159,19 @@
 
     # print('grad vector start')
     # grad 'vector'
-    r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
-      homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+    dense_data = dense * homo_data
+    dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()))
+    r3 = dense_f2(vector)
     r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
-    r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
-      homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+    dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
+                                      if transpose else
+                                      ((dense * a) @ v).sum()),
+                        argnums=(0, 1))
+    r5 = dense_f3(homo_data, vector)
     r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
       homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r5[0], r6[0]))
@@ -190,7 +198,8 @@ def test_heter(self, transpose, shape):
     vector = rng.random(shape[0] if transpose else shape[1])
     vector = bm.as_jax(vector)
 
-    r1 = brainpylib_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
+    dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
+    r1 = (vector @ dense) if transpose else (dense @ vector)
     r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
 
     self.assertTrue(compare_with_nan_tolerance(r1, r2))
@@ -216,8 +225,7 @@ def test_heter_vmap(self, transpose, shape):
 
     dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)
 
-    f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector,
-                 shape=shape, transpose=transpose)
+    f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
     f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
                  shape=shape, transpose=transpose)
     r1 = jax.vmap(f1)(heter_data)
@@ -242,21 +250,26 @@ def test_heter_grad(self, transpose, shape):
     vector = bm.as_jax(vector)
 
     # grad 'data'
-    r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
-      heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
+    dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()),
+                        argnums=0)
+    r1 = dense_f1(dense_data)
     r2 = jax.grad(sum_op(taichi_csr_matvec))(
      heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r1, r2))
 
     # grad 'vector'
-    r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
-      heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+    dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()),
+                        argnums=0)
+    r3 = dense_f2(vector)
     r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
      heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r3, r4))
 
-    r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
-      heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+    dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
+                                      if transpose else
+                                      ((dense * a) @ v).sum()),
+                        argnums=(0, 1))
+    r5 = dense_f3(heter_data, vector)
     r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
      heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
     self.assertTrue(bm.allclose(r5[0], r6[0]))
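
Note: the recurring pattern in this patch is to drop the removed brainpylib_csr_matvec reference and instead validate the remaining csrmv kernels against a dense reference built with bm.sparse.csr_to_dense followed by an ordinary matmul. A minimal standalone sketch of that check is below; only calls that already appear in the patch are used, and the toy CSR arrays, dtypes, and values are illustrative, not taken from the test suite:

import numpy as np
import brainpy.math as bm

# Toy 3 x 4 CSR matrix with 5 nonzeros (illustrative values only).
indptr = bm.as_jax(np.array([0, 2, 3, 5], dtype=np.int32))
indices = bm.as_jax(np.array([0, 2, 1, 0, 3], dtype=np.int32))
data = bm.as_jax(np.array([1., 2., 3., 4., 5.], dtype=np.float32))
shape = (3, 4)
vector = bm.as_jax(np.random.rand(shape[1]).astype(np.float32))

# Dense reference: materialize the sparse matrix and use a plain matmul,
# as the updated tests do in place of the removed brainpylib kernel.
dense = bm.sparse.csr_to_dense(data, indices, indptr, shape=shape)
r_ref = dense @ vector  # transpose=False case; use (vector @ dense) for transpose=True

# Operator under test, called the same way as in the tests.
r_op = bm.sparse.csrmv(data, indices, indptr, vector, shape=shape, transpose=False)

assert bm.allclose(r_ref, r_op)

The event-driven tests follow the same recipe, except that the multiplied vector is a boolean spike array ("events") rather than a float vector.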