Skip to content

Commit

Permalink
Fix test bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 4, 2023
1 parent ace8282 commit 6101133
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 18 deletions.
14 changes: 8 additions & 6 deletions brainpy/_src/math/event/tests/test_event_csrmv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

# bm.set_platform('cpu')

seed = 1234

def sum_op(op):
def func(*args, **kwargs):
r = op(*args, **kwargs)
Expand Down Expand Up @@ -254,7 +256,7 @@ def __init__(self, *args, platform='cpu', **kwargs):
)
def test_homo(self, transpose, shape, homo_data):
print(f'test_homo: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')
rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
events = rng.random(shape[0] if transpose else shape[1]) < 0.1
heter_data = bm.ones(indices.shape) * homo_data
Expand All @@ -277,7 +279,7 @@ def test_homo(self, transpose, shape, homo_data):
def test_homo_vmap(self, shape, transpose, homo_data):
print(f'test_homo_vamp: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')

rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')

# vmap 'data'
Expand Down Expand Up @@ -319,7 +321,7 @@ def test_homo_vmap(self, shape, transpose, homo_data):
def test_homo_grad(self, shape, transpose, homo_data):
print(f'test_homo_grad: shape = {shape}, transpose = {transpose}, homo_data = {homo_data}')

rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
Expand Down Expand Up @@ -351,7 +353,7 @@ def test_homo_grad(self, shape, transpose, homo_data):
)
def test_heter(self, shape, transpose):
print(f'test_heter: shape = {shape}, transpose = {transpose}')
rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
Expand All @@ -377,7 +379,7 @@ def test_heter(self, shape, transpose):
def test_heter_vmap(self, shape, transpose):
print(f'test_heter_vamp: shape = {shape}, transpose = {transpose}')

rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
Expand Down Expand Up @@ -422,7 +424,7 @@ def test_heter_vmap(self, shape, transpose):
def test_heter_grad(self, shape, transpose):
print(f'test_heter_grad: shape = {shape}, transpose = {transpose}')

rng = bm.random.RandomState()
rng = bm.random.RandomState(seed=seed)
indices, indptr = bp.conn.FixedProb(0.4)(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
Expand Down
29 changes: 17 additions & 12 deletions brainpy/_src/math/sparse/tests/test_csrmv_taichi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

# bm.set_platform('gpu')

seed = 1234

def sum_op(op):
def func(*args, **kwargs):
r = op(*args, **kwargs)
Expand Down Expand Up @@ -291,14 +293,14 @@ def __init__(self, *args, platform='cpu', **kwargs):
)
def test_homo(self, transpose, shape, homo_data):
print(f'test_homo: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
conn = bp.conn.FixedProb(0.1)
conn = bp.conn.FixedProb(0.3)

# matrix
indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
indptr = bm.as_jax(indptr)
# vector
rng = bm.random.RandomState(123)
rng = bm.random.RandomState(seed=seed)
vector = rng.random(shape[0] if transpose else shape[1])
vector = bm.as_jax(vector)

Expand All @@ -315,8 +317,8 @@ def test_homo(self, transpose, shape, homo_data):
)
def test_homo_vmap(self, transpose, shape, v):
print(f'test_homo_vmap: transpose = {transpose} shape = {shape}, v = {v}')
rng = bm.random.RandomState()
conn = bp.conn.FixedProb(0.1)
rng = bm.random.RandomState(seed=seed)
conn = bp.conn.FixedProb(0.3)

indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
Expand Down Expand Up @@ -345,8 +347,8 @@ def test_homo_vmap(self, transpose, shape, v):
)
def test_homo_grad(self, transpose, shape, homo_data):
print(f'test_homo_grad: transpose = {transpose} shape = {shape}, homo_data = {homo_data}')
rng = bm.random.RandomState()
conn = bp.conn.FixedProb(0.1)
rng = bm.random.RandomState(seed=seed)
conn = bp.conn.FixedProb(0.3)

indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
Expand Down Expand Up @@ -399,8 +401,8 @@ def test_homo_grad(self, transpose, shape, homo_data):
)
def test_heter(self, transpose, shape):
print(f'test_homo: transpose = {transpose} shape = {shape}')
rng = bm.random.RandomState()
conn = bp.conn.FixedProb(0.1)
rng = bm.random.RandomState(seed=seed)
conn = bp.conn.FixedProb(0.3)

indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
Expand All @@ -415,6 +417,9 @@ def test_heter(self, transpose, shape):
r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape)
r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape)

print(r1)
print(r2[0])

self.assertTrue(compare_with_nan_tolerance(r1, r2[0]))

bm.clear_buffer_memory()
Expand All @@ -424,8 +429,8 @@ def test_heter(self, transpose, shape):
shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
)
def test_heter_vmap(self, transpose, shape):
rng = bm.random.RandomState()
conn = bp.conn.FixedProb(0.1)
rng = bm.random.RandomState(seed=seed)
conn = bp.conn.FixedProb(0.3)

indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
Expand All @@ -451,8 +456,8 @@ def test_heter_vmap(self, transpose, shape):
shape=[(200, 200), (200, 100), (10, 1000), (2, 2000)]
)
def test_heter_grad(self, transpose, shape):
rng = bm.random.RandomState()
conn = bp.conn.FixedProb(0.1)
rng = bm.random.RandomState(seed=seed)
conn = bp.conn.FixedProb(0.3)

indices, indptr = conn(*shape).require('pre2post')
indices = bm.as_jax(indices)
Expand Down

0 comments on commit 6101133

Please sign in to comment.