
Commit

New benchmark method
Routhleck committed Jan 12, 2024
1 parent a3dd5a5 commit 6c3eddd
Showing 8 changed files with 1,520 additions and 2,735 deletions.
559 changes: 119 additions & 440 deletions brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py

Large diffs are not rendered by default.

353 changes: 118 additions & 235 deletions brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py

Large diffs are not rendered by default.

@@ -12,7 +12,7 @@
import pandas as pd
import taichi as ti

-bm.set_platform('gpu')
+bm.set_platform('cpu')

seed = 1234

@@ -25,6 +25,7 @@
37500,
50000
]
+bool_event = False
types = [
'homo',
'uniform',
@@ -45,6 +46,10 @@
w_mu = 0.
w_sigma = 0.1

+ITERATION = 100
+if bm.get_platform() == 'cpu':
+  ITERATION = 10

print(bm.get_platform())

def sum_op(op):
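The body of sum_op is collapsed in this view. A minimal sketch of what such a wrapper would look like (an assumption based on its use with jax.grad below, not taken from this diff):

def sum_op(op):
  # Assumed behavior: call the wrapped operator and sum its output,
  # producing the scalar that jax.grad requires.
  def func(*args, **kwargs):
    return op(*args, **kwargs).sum()
  return func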
@@ -56,39 +61,57 @@ def func(*args, **kwargs):

@partial(jax.jit, static_argnums=(4, 5, 6))
def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)(
-    vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)(
+      vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r

@partial(jax.jit, static_argnums=(4, 5, 6))
def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)(
-    vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)(
+      vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r

@partial(jax.jit, static_argnums=(5, 6, 7))
def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)(
-    vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)(
+      vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r

@partial(jax.jit, static_argnums=(5, 6, 7))
def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)(
-    vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)(
+      vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r

@partial(jax.jit, static_argnums=(5, 6, 7))
def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)(
-    vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)(
+      vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r

@partial(jax.jit, static_argnums=(5, 6, 7))
def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel):
-  return jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)(
-    vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
-  )
+  r = 0
+  for i in range(ITERATION):
+    r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)(
+      vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel
+    )
+  return r
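For reference, a minimal timing sketch of how one of these jitted, ITERATION-fold benchmark functions could be measured; the bench helper below is illustrative only (the actual measurement code lives in the collapsed test functions):

import time

def bench(fn, *args, **kwargs):
  # Warm-up call: triggers JIT compilation, so it is excluded from the timing.
  jax.block_until_ready(fn(*args, **kwargs))
  t0 = time.time()
  jax.block_until_ready(fn(*args, **kwargs))
  # Each call repeats the kernel ITERATION times, so report the per-call average in ms.
  return (time.time() - t0) / ITERATION * 1000.0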

def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel):
rng = bm.random.RandomState(seed=seed)
@@ -448,11 +471,11 @@ def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel):
print('taichi_aot_3: ', taichi_aot_time3, 'ms')
print('taichi_aot_4: ', taichi_aot_time4, 'ms')
print('taichi_aot_5: ', taichi_aot_time5, 'ms')
-  print('brainpylib_gpu_1: ', brainpy_time1, 'ms')
-  print('brainpylib_gpu_2: ', brainpy_time2, 'ms')
-  print('brainpylib_gpu_3: ', brainpy_time3, 'ms')
-  print('brainpylib_gpu_4: ', brainpy_time4, 'ms')
-  print('brainpylib_gpu_5: ', brainpy_time5, 'ms')
+  print('brainpylib_1: ', brainpy_time1, 'ms')
+  print('brainpylib_2: ', brainpy_time2, 'ms')
+  print('brainpylib_3: ', brainpy_time3, 'ms')
+  print('brainpylib_4: ', brainpy_time4, 'ms')
+  print('brainpylib_5: ', brainpy_time5, 'ms')
# assert(jnp.allclose(result1[0], result2))

speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
@@ -543,11 +566,11 @@ def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel):
print('taichi_aot_3: ', taichi_aot_time3, 'ms')
print('taichi_aot_4: ', taichi_aot_time4, 'ms')
print('taichi_aot_5: ', taichi_aot_time5, 'ms')
-  print('brainpylib_gpu_1: ', brainpy_time1, 'ms')
-  print('brainpylib_gpu_2: ', brainpy_time2, 'ms')
-  print('brainpylib_gpu_3: ', brainpy_time3, 'ms')
-  print('brainpylib_gpu_4: ', brainpy_time4, 'ms')
-  print('brainpylib_gpu_5: ', brainpy_time5, 'ms')
+  print('brainpylib_1: ', brainpy_time1, 'ms')
+  print('brainpylib_2: ', brainpy_time2, 'ms')
+  print('brainpylib_3: ', brainpy_time3, 'ms')
+  print('brainpylib_4: ', brainpy_time4, 'ms')
+  print('brainpylib_5: ', brainpy_time5, 'ms')
# assert(jnp.allclose(result1[0], result2))

speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
@@ -638,11 +661,11 @@ def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel):
print('taichi_aot_3: ', taichi_aot_time3, 'ms')
print('taichi_aot_4: ', taichi_aot_time4, 'ms')
print('taichi_aot_5: ', taichi_aot_time5, 'ms')
-  print('brainpylib_gpu_1: ', brainpy_time1, 'ms')
-  print('brainpylib_gpu_2: ', brainpy_time2, 'ms')
-  print('brainpylib_gpu_3: ', brainpy_time3, 'ms')
-  print('brainpylib_gpu_4: ', brainpy_time4, 'ms')
-  print('brainpylib_gpu_5: ', brainpy_time5, 'ms')
+  print('brainpylib_1: ', brainpy_time1, 'ms')
+  print('brainpylib_2: ', brainpy_time2, 'ms')
+  print('brainpylib_3: ', brainpy_time3, 'ms')
+  print('brainpylib_4: ', brainpy_time4, 'ms')
+  print('brainpylib_5: ', brainpy_time5, 'ms')
# assert(jnp.allclose(result1[0], result2))

speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \
@@ -711,18 +734,3 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel):
taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup]
df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False)

-# if (bm.get_platform() == 'gpu'):
-#   for _s in s:
-#     for _p in p:
-#       taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p)
-#       df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0]
-#   df.to_csv('event_ell_gpu.csv', index=False)
-
-# df = pd.read_csv('event_ell_gpu.csv')
-# for _s in s:
-#   for _p in p:
-#     brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p)
-#     # find the corresponding row
-#     df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time
-#   df.to_csv('event_ell_gpu.csv', index=False)