Skip to content

Commit

Permalink
Update benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 12, 2023
1 parent db7758f commit 0ef3782
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@
p = [0.1, 0.2, 0.3, 0.4, 0.5]
values_type = ['homo', 'heter']
events_type = ['float']
transpose = [True, False]
transpose = [
True,
False
]
method = 'cusparse'

print(bm.get_platform())

Expand Down Expand Up @@ -182,7 +186,7 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
time9 = time.time()

result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
# print('--------------------result1[0]------------------')
# print(result1[0])
# print('--------------------result2------------------')
Expand All @@ -198,26 +202,26 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose):
# assert bm.allclose(result1[0], result2)

time12 = time.time()
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
time13 = time.time()
# time.sleep(2)

time14 = time.time()
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
time15 = time.time()
# time.sleep(2)

time16 = time.time()
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
time17 = time.time()
# time.sleep(2)

time18 = time.time()
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
time19 = time.time()

time20 = time.time()
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True))
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method))
time21 = time.time()

taichi_aot_time1 = (time1 - time0) * 1000
Expand Down

0 comments on commit 0ef3782

Please sign in to comment.