diff --git a/brainpy/_src/math/jitconn/_matvec_taichi.py b/brainpy/_src/math/jitconn/_matvec_taichi.py index df1aebff8..580c7260e 100644 --- a/brainpy/_src/math/jitconn/_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/_matvec_taichi.py @@ -45,12 +45,13 @@ def _mv_prob_homo_outdim_parallel_cpu( num_col = shape[1] weight_value = weight[0] clen_value = clen[0] + seed_value = seed[0] ti.loop_config(serialize=True) for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) @@ -76,22 +77,27 @@ def _mv_prob_homo_outdim_parallel_gpu( num_col = shape[1] weight_value = weight[0] clen_value = clen[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_col * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) 
- + + i_col = i >> 5 + index = i & 31 + s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_row = uniform_int_distribution(result, 1, clen_value) + i_row = uniform_int_distribution(result, 1, clen_value) + avg_num_uniform * index v = vector[i_col] * weight_value while i_row < num_row: s1, s2, s3, b, result = random_generator(s1, s2, s3, b) out[i_row] += uniform_int_distribution(result, 1, clen_value) * v s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_row += uniform_int_distribution(result, 1, clen_value) + i_row += uniform_int_distribution(result, 1, clen_value) * 32 @ti.kernel @@ -107,12 +113,13 @@ def _mv_prob_homo_cpu( num_col = shape[1] weight_value = weight[0] clen_value = clen[0] - + seed_value = seed[0] + ti.loop_config(serialize=True) for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result = ti.f32(0.) @@ -138,22 +145,27 @@ def _mv_prob_homo_gpu( num_col = shape[1] weight_value = weight[0] clen_value = clen[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_row * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result = ti.f32(0.) 
+ + i_row = i >> 5 + index = i & 31 s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_col = uniform_int_distribution(result, 1, clen_value) + i_col = uniform_int_distribution(result, 1, clen_value) + avg_num_uniform * index while i_col < num_col: r += vector[i_col] s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_col += uniform_int_distribution(result, 1, clen_value) - out[i_row] = r * weight_value + i_col += uniform_int_distribution(result, 1, clen_value) * 32 + out[i_row] += r * weight_value def _mv_prob_homo_jvp( primals, tangents, *, outs, shape, transpose, outdim_parallel, conn_prob @@ -374,12 +386,13 @@ def _mv_prob_uniform_outdim_parallel_cpu( clen_value = clen[0] w_min_value = w_min[0] w_max_value = w_max[0] + seed_value = seed[0] ti.loop_config(serialize=True) for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) @@ -406,21 +419,26 @@ def _mv_prob_uniform_outdim_parallel_gpu( clen_value = clen[0] w_min_value = w_min[0] w_max_value = w_max[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_col * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) 
+ + i_col = i >> 5 + index = i & 31 s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_row = uniform_int_distribution(result, 1, clen_value) + i_row = uniform_int_distribution(result, 1, clen_value) + avg_num_uniform * index while i_row < num_row: s1, s2, s3, b, result = random_generator(s1, s2, s3, b) out[i_row] += uniform_real_distribution(result, w_min_value, w_max_value) * vector[i_col] s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_row += uniform_int_distribution(result, 1, clen_value) + i_row += uniform_int_distribution(result, 1, clen_value) * 32 @ti.kernel def _mv_prob_uniform_cpu( @@ -437,11 +455,13 @@ def _mv_prob_uniform_cpu( clen_value = clen[0] w_min_value = w_min[0] w_max_value = w_max[0] + seed_value = seed[0] + ti.loop_config(serialize=True) for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) r = 0. @@ -470,23 +490,28 @@ def _mv_prob_uniform_gpu( clen_value = clen[0] w_min_value = w_min[0] w_max_value = w_max[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_row * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) result = ti.f32(0.) r = 0. 
+ + i_row = i >> 5 + index = i & 31 s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_col = uniform_int_distribution(result, 1, clen_value) + i_col = uniform_int_distribution(result, 1, clen_value) + avg_num_uniform * index while i_col < num_col: s1, s2, s3, b, result = random_generator(s1, s2, s3, b) r += uniform_real_distribution(result, w_min_value, w_max_value) * vector[i_col] s1, s2, s3, b, result = random_generator(s1, s2, s3, b) - i_col += uniform_int_distribution(result, 1, clen_value) - out[i_row] = r + i_col += uniform_int_distribution(result, 1, clen_value) * 32 + out[i_row] += r def _mv_prob_uniform_jvp( primals, tangents, *, outs, shape, transpose, outdim_parallel, conn_prob @@ -703,12 +728,13 @@ def _mv_prob_normal_outdim_parallel_cpu( clen_value = clen[0] w_mu_value = w_mu[0] w_sigma_value = w_sigma[0] + seed_value = seed[0] ti.loop_config(serialize=True) for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result1 = ti.f32(0.) 
@@ -720,7 +746,7 @@ def _mv_prob_normal_outdim_parallel_cpu( s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) out[i_row] += normal_distribution(result1, result2, w_mu_value, w_sigma_value) * vector[i_col] - s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) + s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) i_row += uniform_int_distribution(result1, 1, clen_value) @ti.kernel @@ -738,24 +764,29 @@ def _mv_prob_normal_outdim_parallel_gpu( clen_value = clen[0] w_mu_value = w_mu[0] w_sigma_value = w_sigma[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_col in range(num_col): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_col * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result1 = ti.f32(0.) result2 = ti.f32(0.) + + i_col = i >> 5 + index = i & 31 s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) - i_row = uniform_int_distribution(result1, 1, clen_value) + i_row = uniform_int_distribution(result1, 1, clen_value) + avg_num_uniform * index while i_row < num_row: s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) out[i_row] += normal_distribution(result1, result2, w_mu_value, w_sigma_value) * vector[i_col] - s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) - i_row += uniform_int_distribution(result1, 1, clen_value) + s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) + i_row += uniform_int_distribution(result1, 1, clen_value) * 32 @ti.kernel def _mv_prob_normal_cpu( @@ -772,12 +803,13 @@ def _mv_prob_normal_cpu( clen_value = clen[0] w_mu_value = w_mu[0] w_sigma_value = w_sigma[0] + seed_value = seed[0] ti.loop_config(serialize=True) for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + s1 = 
seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result1 = ti.f32(0.) @@ -789,7 +821,7 @@ def _mv_prob_normal_cpu( s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) r += normal_distribution(result1, result2, w_mu_value, w_sigma_value) * vector[i_col] - s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) + s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) i_col += uniform_int_distribution(result1, 1, clen_value) out[i_row] = r @@ -808,25 +840,30 @@ def _mv_prob_normal_gpu( clen_value = clen[0] w_mu_value = w_mu[0] w_sigma_value = w_sigma[0] + seed_value = seed[0] + avg_num_uniform = ti.i32((clen_value + 1) /2) - for i_row in range(num_row): - s1 = seed[0] + 1 + ti.global_thread_idx() - s2 = seed[0] + 7 - s3 = seed[0] + 15 + for i in range(num_row * 32): + s1 = seed_value + 1 + ti.global_thread_idx() + s2 = seed_value + 7 + s3 = seed_value + 15 b = ti.u32(0) r = 0. result1 = ti.f32(0.) result2 = ti.f32(0.) 
+ + i_row = i >> 5 + index = i & 31 s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) - i_col = uniform_int_distribution(result1, 1, clen_value) + i_col = uniform_int_distribution(result1, 1, clen_value) + avg_num_uniform * index while i_col < num_col: s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) r += normal_distribution(result1, result2, w_mu_value, w_sigma_value) * vector[i_col] - s1, s2, s3, b, result2 = random_generator(s1, s2, s3, b) - i_col += uniform_int_distribution(result1, 1, clen_value) - out[i_row] = r + s1, s2, s3, b, result1 = random_generator(s1, s2, s3, b) + i_col += uniform_int_distribution(result1, 1, clen_value) * 32 + out[i_row] += r def _mv_prob_normal_jvp( primals, tangents, *, outs, shape, transpose, outdim_parallel, conn_prob diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py new file mode 100644 index 000000000..92def9be6 --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py @@ -0,0 +1,694 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('gpu') + +seed = 1234 + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 + ] +types = [ + 'homo', + 'uniform', + 'normal' + ] +transpose = [ + True, + False + ] +outdim_parallel = [ + True, + False, + ] +conn_prob = 0.1 +homo_data = 1. +w_low = 0. +w_high = 1. +w_mu = 0. 
+w_sigma = 0.1 + +print(bm.get_platform()) + +def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # 
print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + 
print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = 
time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, 
transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, 
w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = 
jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + 
print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + 
taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, 
conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + 
(taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 
1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + + +def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel) + elif _type == 'uniform': + return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel) + elif _type == 'normal': + return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel) + else: + raise ValueError + + +def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel) + 
elif _type == 'uniform': + return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel) + elif _type == 'normal': + return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel) + else: + raise ValueError + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'speedup']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'gpu', 
_type, _transpose, _outdim_parallel, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False) + +# if (bm.get_platform() == 'gpu'): +# for _s in s: +# for _p in p: +# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) +# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] +# df.to_csv('event_ell_gpu.csv', index=False) + + # df = pd.read_csv('event_ell_gpu.csv') + # for _s in s: + # for _p in p: + # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) + # # find the matching row + # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time + # df.to_csv('event_ell_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py index b75d959b8..f04b3459c 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec_taichi.py @@ -288,7 +288,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non rng = bm.random.RandomState() vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - r1 = bm.jitconn.mv_prob_homo(vector, + r1 = bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=prob, shape=shape, @@ -296,7 +296,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non outdim_parallel=outdim_parallel, transpose=transpose)[0] - r2 = bm.jitconn.mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=prob, shape=shape, @@ -305,7 +305,7 @@ def test_homo(self, shape, transpose, outdim_parallel, prob, homo_data, seed=Non transpose=transpose)[0] self.assertTrue(jnp.allclose(r1, r2)) - r2 = bm.jitconn.mv_prob_homo(vector, + r2 = bm.jitconn.mv_prob_homo_taichi(vector, homo_data, 
conn_prob=prob, shape=(shape[1], shape[0]), @@ -348,7 +348,7 @@ def test_homo_vmap(self, shape, transpose, outdim_parallel, prob, seed=None, x64 weights = bm.as_jax(rng.random(10)) f1 = jax.vmap( - lambda event, data: bm.jitconn.mv_prob_homo( + lambda event, data: bm.jitconn.mv_prob_homo_taichi( event, data, conn_prob=prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose @@ -393,7 +393,7 @@ def test_homo_grad(self, shape, transpose, outdim_parallel, prob, seed=None, x64 events = events.astype(float) f1 = jax.grad( - lambda event, data: bm.jitconn.mv_prob_homo( + lambda event, data: bm.jitconn.mv_prob_homo_taichi( event, data, conn_prob=prob, shape=shape,