Skip to content

Commit

Permalink
[math] Fix pytest bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jan 29, 2024
1 parent 046dbea commit df8c0bf
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 34 deletions.
28 changes: 14 additions & 14 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import brainpy as bp
import brainpy.math as bm
from .._csr_matvec import csrmv_brainpylib as brainpylib_csr_matvec

seed = 1234

Expand Down Expand Up @@ -44,7 +43,8 @@ def test_homo(self, transpose, shape, homo_data):
events = rng.random(shape[0] if transpose else shape[1]) < 0.1
heter_data = bm.ones(indices.shape) * homo_data

r1 = brainpylib_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)
dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
r1 = (events @ dense) if transpose else (dense @ events)
r2 = taichi_csr_matvec(homo_data, indices, indptr, events, shape=shape, transpose=transpose)

assert (bm.allclose(r1, r2))
Expand All @@ -67,23 +67,23 @@ def test_homo_vmap(self, shape, transpose, homo_data):

# vmap 'data'
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
vmap_data = bm.as_jax([homo_data] * 10)
self.assertTrue(bm.allclose(f1(vmap_data), f2(vmap_data)))

# vmap 'events'
f3 = jax.vmap(partial(brainpylib_csr_matvec, homo_data, indices, indptr,
f3 = jax.vmap(partial(bm.sparse.csrmv, homo_data, indices, indptr,
shape=shape, transpose=transpose))
f4 = jax.vmap(partial(taichi_csr_matvec, homo_data, indices, indptr,
shape=shape, transpose=transpose))
vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))

# vmap 'data' and 'events'
f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))
f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee, shape=shape, transpose=transpose))
f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee, shape=shape, transpose=transpose))

vmap_data1 = bm.as_jax([homo_data] * 10)
Expand Down Expand Up @@ -112,14 +112,14 @@ def test_homo_grad(self, shape, transpose, homo_data):
dense_conn = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, indices, indptr, shape=shape)

# grad 'data'
r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
r1 = jax.grad(sum_op(bm.sparse.csrmv))(
homo_data, indices, indptr, events, shape=shape, transpose=transpose)
r2 = jax.grad(sum_op(taichi_csr_matvec))(
homo_data, indices, indptr, events, shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))

# grad 'events'
r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
homo_data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
Expand All @@ -143,7 +143,7 @@ def test_heter(self, shape, transpose):
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
heter_data = bm.as_jax(rng.random(indices.shape))

r1 = brainpylib_csr_matvec(heter_data, indices, indptr, events,
r1 = bm.sparse.csrmv(heter_data, indices, indptr, events,
shape=shape, transpose=transpose)
r2 = taichi_csr_matvec(heter_data, indices, indptr, events,
shape=shape, transpose=transpose)
Expand All @@ -169,7 +169,7 @@ def test_heter_vmap(self, shape, transpose):

# vmap 'data'
events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1
f1 = jax.vmap(partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, events=events,
f1 = jax.vmap(partial(bm.sparse.csrmv, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
f2 = jax.vmap(partial(taichi_csr_matvec, indices=indices, indptr=indptr, events=events,
shape=shape, transpose=transpose))
Expand All @@ -178,15 +178,15 @@ def test_heter_vmap(self, shape, transpose):

# vmap 'events'
data = bm.as_jax(rng.random(indices.shape))
f3 = jax.vmap(partial(brainpylib_csr_matvec, data, indices, indptr,
f3 = jax.vmap(partial(bm.sparse.csrmv, data, indices, indptr,
shape=shape, transpose=transpose))
f4 = jax.vmap(partial(taichi_csr_matvec, data, indices, indptr,
shape=shape, transpose=transpose))
vmap_data = bm.as_jax(rng.random((10, shape[0] if transpose else shape[1]))) < 0.1
self.assertTrue(bm.allclose(f3(vmap_data), f4(vmap_data)))

# vmap 'data' and 'events'
f5 = jax.vmap(lambda dd, ee: brainpylib_csr_matvec(dd, indices, indptr, ee,
f5 = jax.vmap(lambda dd, ee: bm.sparse.csrmv(dd, indices, indptr, ee,
shape=shape, transpose=transpose))
f6 = jax.vmap(lambda dd, ee: taichi_csr_matvec(dd, indices, indptr, ee,
shape=shape, transpose=transpose))
Expand Down Expand Up @@ -217,20 +217,20 @@ def test_heter_grad(self, shape, transpose):

# grad 'data'
data = bm.as_jax(rng.random(indices.shape))
r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
r1 = jax.grad(sum_op(bm.sparse.csrmv))(
data, indices, indptr, events, shape=shape, transpose=transpose)
r2 = jax.grad(sum_op(taichi_csr_matvec))(
data, indices, indptr, events, shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))

# grad 'events'
r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
r3 = jax.grad(sum_op(bm.sparse.csrmv), argnums=3)(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r3, r4))

r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
r5 = jax.grad(sum_op(bm.sparse.csrmv), argnums=(0, 3))(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
data, indices, indptr, events.astype(float), shape=shape, transpose=transpose)
Expand Down
53 changes: 33 additions & 20 deletions brainpy/_src/math/sparse/tests/test_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import brainpy as bp
import brainpy.math as bm
from .._csr_mv import csrmv_brainpylib as brainpylib_csr_matvec

# bm.set_platform('gpu')

Expand Down Expand Up @@ -80,7 +79,10 @@ def test_homo(self, transpose, shape, homo_data):
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)

r1 = brainpylib_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
heter_data = bm.ones(indices.shape).value * homo_data

dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
r1 = (vector @ dense) if transpose else (dense @ vector)
r2 = taichi_csr_matvec(homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))

Expand All @@ -106,12 +108,11 @@ def test_homo_vmap(self, transpose, shape, v):
homo_data = bm.ones(10).value * v
dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=shape))(heter_data)

f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
r1 = jax.vmap(f1)(homo_data)
r2 = jax.vmap(f1)(homo_data)
r2 = jax.vmap(f2)(homo_data)
self.assertTrue(bm.allclose(r1, r2))

bm.clear_buffer_memory()
Expand All @@ -138,8 +139,11 @@ def test_homo_grad(self, transpose, shape, homo_data):

# print('grad data start')
# grad 'data'
r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
homo_data, indices, indptr, vector, shape=shape, transpose=transpose)
dense_f1 = jax.grad(lambda a: ((vector @ (dense * a)).sum()
if transpose else
((dense * a) @ vector).sum()),
argnums=0)
r1 = dense_f1(homo_data)
r2 = jax.grad(sum_op(taichi_csr_matvec))(
homo_data, indices, indptr, vector, shape=shape, transpose=transpose)

Expand All @@ -155,15 +159,19 @@ def test_homo_grad(self, transpose, shape, homo_data):

# print('grad vector start')
# grad 'vector'
r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
dense_data = dense * homo_data
dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()))
r3 = dense_f2(vector)
r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)

self.assertTrue(bm.allclose(r3, r4))

r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
if transpose else
((dense * a) @ v).sum()),
argnums=(0, 1))
r5 = dense_f3(homo_data, vector)
r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
homo_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r5[0], r6[0]))
Expand All @@ -190,7 +198,8 @@ def test_heter(self, transpose, shape):
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)

r1 = brainpylib_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape)
r1 = (vector @ dense) if transpose else (dense @ vector)
r2 = taichi_csr_matvec(heter_data, indices, indptr, vector, shape=shape)

self.assertTrue(compare_with_nan_tolerance(r1, r2))
Expand All @@ -216,8 +225,7 @@ def test_heter_vmap(self, transpose, shape):
dense_data = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr,
shape=shape))(heter_data)

f1 = partial(brainpylib_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
f1 = lambda a: (a.T @ vector) if transpose else (a @ vector)
f2 = partial(taichi_csr_matvec, indices=indices, indptr=indptr, vector=vector,
shape=shape, transpose=transpose)
r1 = jax.vmap(f1)(heter_data)
Expand All @@ -242,21 +250,26 @@ def test_heter_grad(self, transpose, shape):
vector = bm.as_jax(vector)

# grad 'data'
r1 = jax.grad(sum_op(brainpylib_csr_matvec))(
heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
dense_f1 = jax.grad(lambda a: ((vector @ a).sum() if transpose else (a @ vector).sum()),
argnums=0)
r1 = dense_f1(dense_data)
r2 = jax.grad(sum_op(taichi_csr_matvec))(
heter_data, indices, indptr, vector, shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r1, r2))

# grad 'vector'
r3 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=3)(
heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
dense_f2 = jax.grad(lambda v: ((v @ dense_data).sum() if transpose else (dense_data @ v).sum()),
argnums=0)
r3 = dense_f2(vector)
r4 = jax.grad(sum_op(taichi_csr_matvec), argnums=3)(
heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r3, r4))

r5 = jax.grad(sum_op(brainpylib_csr_matvec), argnums=(0, 3))(
heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
dense_f3 = jax.grad(lambda a, v: ((v @ (dense * a)).sum()
if transpose else
((dense * a) @ v).sum()),
argnums=(0, 1))
r5 = dense_f3(heter_data, vector)
r6 = jax.grad(sum_op(taichi_csr_matvec), argnums=(0, 3))(
heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
self.assertTrue(bm.allclose(r5[0], r6[0]))
Expand Down

0 comments on commit df8c0bf

Please sign in to comment.