diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py index 606715e91..ea67668c0 100644 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py @@ -18,7 +18,11 @@ p = [0.1, 0.2, 0.3, 0.4, 0.5] values_type = ['homo', 'heter'] events_type = ['float'] -transpose = [True, False] +transpose = [ + True, + False + ] +method = 'cusparse' print(bm.get_platform()) @@ -182,7 +186,7 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose): result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) time9 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) # print('--------------------result1[0]------------------') # print(result1[0]) # print('--------------------result2------------------') @@ -198,26 +202,26 @@ def test_event_ell_gpu(s, p, values_type, events_type, transpose): # assert bm.allclose(result1[0], result2) time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) time13 = time.time() # time.sleep(2) time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) time15 = time.time() # time.sleep(2) time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) time17 = time.time() # time.sleep(2) time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) time19 = time.time() time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=True)) + result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) time21 = time.time() taichi_aot_time1 = (time1 - time0) * 1000