From 22296144458f68774412cd3cae74e6b732e15497 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 3 Jan 2024 11:47:39 +0800 Subject: [PATCH 1/7] [math] Support taichi customized op with metal cpu backend --- .../_src/math/op_register/taichi_aot_based.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 06d0508a1..04ec7a690 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -2,6 +2,7 @@ import inspect import os import pathlib +import platform import re from functools import partial, reduce from typing import Any, Sequence @@ -59,10 +60,20 @@ def get_source_with_dependencies(func, visited=None): return source +# check if Metal is supported +def is_metal_supported(): + # first check if we are on macOS + if platform.system() != 'Darwin': + return False + + if platform.processor() != 'arm': + return False + return True ### VARIABLES ### home_path = get_home_dir() kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels') +is_metal_device = is_metal_supported() # check if a kernel exists in the database @@ -124,7 +135,11 @@ def _build_kernel( # init arch arch = None if device == 'cpu': - arch = ti.x64 + if is_metal_device: + arch = ti.arm64 + device == 'arm64' + else: + arch = ti.x64 elif device == 'gpu': arch = ti.cuda @@ -328,9 +343,14 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs): in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs) ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) + if is_metal_supported: + fn = b'taichi_kernel_aot_call_cpu_arm64' + else: + fn = b'taichi_kernel_aot_call_cpu' + return xla_client.ops.CustomCallWithLayout( c, - b'taichi_kernel_aot_call_cpu', + fn, operands=ins, 
operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), shape_with_layout=xla_client.Shape.tuple_shape( From 9a13306694b67701343b220f28258966e7e95f83 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 4 Jan 2024 20:11:27 +0800 Subject: [PATCH 2/7] Update taichi_aot_based.py --- brainpy/_src/math/op_register/taichi_aot_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 04ec7a690..49d507b8d 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -343,7 +343,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs): in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs) ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) - if is_metal_supported: + if is_metal_device: fn = b'taichi_kernel_aot_call_cpu_arm64' else: fn = b'taichi_kernel_aot_call_cpu' From f582794ab3c8786b291d94d4a15b2237be38883b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 4 Jan 2024 20:12:24 +0800 Subject: [PATCH 3/7] Update taichi_aot_based.py --- brainpy/_src/math/op_register/taichi_aot_based.py | 1 + 1 file changed, 1 insertion(+) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 49d507b8d..5d053611f 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -343,6 +343,7 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs): in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs) ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) + fn = None if is_metal_device: fn = b'taichi_kernel_aot_call_cpu_arm64' else: From 
a3dd5a53a44378e5a459f9ddfa7f8ae253e570f9 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 9 Jan 2024 11:55:58 +0800 Subject: [PATCH 4/7] Update benchmarks --- .../event_csrmv_taichi_VS_event_csrmv.py | 8 +- .../event_csrmv_taichi_VS_event_csrmv_grad.py | 388 ++++++++++ ...t_matvec_taichi_VS_jitconn_event_matvec.py | 6 +- ...vec_taichi_VS_jitconn_event_matvec_grad.py | 726 +++++++++++++++++ ...jitconn_matvec_taichi_VS_jitconn_matvec.py | 4 +- ...nn_matvec_taichi_VS_jitconn_matvec_grad.py | 728 ++++++++++++++++++ brainpy/_src/math/sparse/_csr_mv_taichi.py | 10 +- .../sparse/tests/csrmv_taichi_VS_csrmv.py | 43 +- .../tests/csrmv_taichi_VS_csrmv_grad.py | 340 ++++++++ 9 files changed, 2220 insertions(+), 33 deletions(-) create mode 100644 brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py create mode 100644 brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py create mode 100644 brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py create mode 100644 brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py index 8e290fa35..658c73573 100644 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') s = [1000, 5000, 10000, 20000, 25000, 30000] p = [0.1, 0.2, 0.3, 0.4, 0.5] @@ -46,7 +46,7 @@ def test_event_csrmv_cpu(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. 
@@ -151,7 +151,7 @@ def test_event_csrmv_cpu(shape, values_type, events_type, transpose): def test_event_csrmv_gpu(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. @@ -572,4 +572,4 @@ def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) # # 找到对应的行 # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py new file mode 100644 index 000000000..f4feb7f9e --- /dev/null +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py @@ -0,0 +1,388 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +s = [1000, 5000, 10000, 20000, 25000, 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 +] + + + +values_type = [ + 'homo', + 'heter' + ] +events_type = [ + 'bool', + 'float', + ] +transpose = [ + True, + False + ] + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +def sum_op2(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5)) +def 
event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): + return jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose): + return jax.grad(sum_op(bm.event.csrmv), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + + + +def test_event_csrmv_cpu(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 + weight = 1. + + + if events_type == 'float': + vector = vector.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + + # assert(bm.allclose(result1, result2)) + + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, 
vector.astype(float), shape=shape, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + + time12 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = 
(time21 - time20) * 1000 + + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + + # assert(jnp.allclose(result1, result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_event_csrmv_gpu(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 + weight = 1. 
+ + + if events_type == 'float': + vector = vector.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + + + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(event_csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + + time12 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = 
jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(event_csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + + # if not bm.allclose(result1, result2): + # print('False') + # diff = result1 - result2 + # print(diff[:1000]) + # print(diff[-1000:]) + + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, 
taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'speedup']) + +### SQUARE MATRIX + +# if (bm.get_platform() == 'cpu'): +# for _s in s: +# for _p in p: +# for _values_type in values_type: +# for _events_type in events_type: +# for _transpose in transpose: +# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) +# # append to dataframe +# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'cpu', _values_type, _events_type, _transpose, +# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] +# df.to_csv(f'{PATH}/event_csrmv_square_cpu.csv', index=False) + +# if (bm.get_platform() == 'gpu'): +# for _s in s: +# for _p in p: +# for _values_type in values_type: +# for _events_type in events_type: +# for _transpose in transpose: +# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) +# # append to dataframe +# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'gpu', _values_type, _events_type, _transpose, +# taichi_aot_time_1, taichi_aot_time_2, 
taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] +# df.to_csv(f'{PATH}/event_csrmv_square_gpu.csv', index=False) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2,'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False) + + +# if (bm.get_platform() == 'gpu'): +# for _s in s: +# for _p in p: +# 
taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) +# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] +# df.to_csv('event_ell_gpu.csv', index=False) + + # df = pd.read_csv('event_ell_gpu.csv') + # for _s in s: + # for _p in p: + # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) + # # 找到对应的行 + # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py index 249438a48..94c96dabb 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('cpu') +bm.set_platform('gpu') seed = 1234 @@ -42,7 +42,7 @@ True, False ] -conn_prob = 0.1 +conn_prob = 0.05 homo_data = 1. w_low = 0. w_high = 1. 
@@ -705,4 +705,4 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) # # 找到对应的行 # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py new file mode 100644 index 000000000..3603778dd --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py @@ -0,0 +1,726 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('gpu') +# bm.disable_gpu_memory_preallocation() + +seed = 1234 + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 + ] +types = [ + 'homo', + 'uniform', + 'normal' + ] +transpose = [ + True, + False + ] +outdim_parallel = [ + True, + False, + ] +bool_event = [ + True, + False + ] +conn_prob = 0.05 +homo_data = 1. +w_low = 0. +w_high = 1. +w_mu = 0. 
+w_sigma = 0.1 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)( + vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, 
static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)( + vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) 
+ time9 = time.time() + + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', 
taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = 
time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = 
(time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = 
jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + print('start') + 
+ result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + # bm.clear_buffer_memory() + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + # bm.clear_buffer_memory() + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + # bm.clear_buffer_memory() + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + # bm.clear_buffer_memory() + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + # bm.clear_buffer_memory() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + print('taichi finished') + + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = 
jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + print('brainpylib finished') + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + 
(taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + bm.clear_buffer_memory() + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + print('start') + + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, 
transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + print('taichi finished') + + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + print('brainpylib finished') + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + 
print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + bm.clear_buffer_memory() + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = 
jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = 
time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + bm.clear_buffer_memory() + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + + +def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel, bool_event): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event) + elif _type == 'uniform': + return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event) + elif _type == 'normal': + return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event) + else: + raise ValueError + + +def 
test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event) + elif _type == 'uniform': + return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event) + elif _type == 'normal': + return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event) + else: + raise ValueError + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'speedup']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + for _bool_event in bool_event: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for 
_outdim_parallel in outdim_parallel: + for _transpose in transpose: + for _bool_event in bool_event: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False) + +# if (bm.get_platform() == 'gpu'): +# for _s in s: +# for _p in p: +# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) +# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] +# df.to_csv('event_ell_gpu.csv', index=False) + + # df = pd.read_csv('event_ell_gpu.csv') + # for _s in s: + # for _p in p: + # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) + # # 找到对应的行 + # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py index 92def9be6..fef4d3aeb 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py @@ -38,7 +38,7 @@ True, False, ] -conn_prob = 0.1 +conn_prob = 0.05 homo_data = 1. w_low = 0. w_high = 1. 
@@ -691,4 +691,4 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) # # 找到对应的行 # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py new file mode 100644 index 000000000..512bc1511 --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py @@ -0,0 +1,728 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('gpu') + +seed = 1234 + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 + ] +types = [ + 'homo', + 'uniform', + 'normal' + ] +transpose = [ + True, + False + ] +outdim_parallel = [ + True, + False, + ] +conn_prob = 0.05 +homo_data = 1. +w_low = 0. +w_high = 1. +w_mu = 0. 
+w_sigma = 0.1 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + return jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + return 
jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + +def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time12 = time.time() + result2 = 
jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', 
# ---------------------------------------------------------------------------
# Shared benchmark helpers.
#
# Every test_jitconn_matvec_* function below used to repeat the same
# warm-up / five-timed-runs / print / speedup boilerplate; it is factored
# into _time_runs / _compare_impls so each benchmark is a few lines and
# all of them stay consistent.
# ---------------------------------------------------------------------------

def _time_runs(fn, args, kwargs, label, n_runs=5):
    """Time ``n_runs`` executions of ``fn(*args, **kwargs)`` in milliseconds.

    An initial untimed call triggers JIT compilation so the timed runs
    measure steady-state execution only.  Each run is printed as
    ``<label>_<i>:  <time> ms`` and the times are returned as a list.
    """
    jax.block_until_ready(fn(*args, **kwargs))  # warm-up / compile run (untimed)
    times = []
    for _ in range(n_runs):
        t0 = time.time()
        jax.block_until_ready(fn(*args, **kwargs))
        times.append((time.time() - t0) * 1000)
    for i, t in enumerate(times):
        print(f'{label}_{i + 1}: ', t, 'ms')
    return times


def _compare_impls(taichi_fn, brainpy_fn, args, kwargs, device):
    """Benchmark the taichi-AOT and brainpylib implementations on one input.

    Returns an 11-tuple: the five taichi times, the five brainpylib times
    (all in ms) and ``speedup = sum(brainpy) / sum(taichi) - 1`` (positive
    means the taichi kernel is faster overall) — the shape every
    ``test_jitconn_matvec_*`` function below returns.
    """
    taichi_times = _time_runs(taichi_fn, args, kwargs, 'taichi_aot')
    brainpy_times = _time_runs(brainpy_fn, args, kwargs, f'brainpylib_{device}')
    speedup = sum(brainpy_times) / sum(taichi_times) - 1
    return (*taichi_times, *brainpy_times, speedup)


def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel):
    """Benchmark the uniform-weight jitconn matvec gradient on CPU."""
    rng = bm.random.RandomState(seed=seed)
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
    kwargs = dict(shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
    return _compare_impls(jitconn_matvec_uniform_taichi_grad, jitconn_matvec_uniform_grad,
                          (events, w_low, w_high, conn_prob, seed), kwargs, 'cpu')


def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel):
    """Benchmark the normal-weight jitconn matvec gradient on CPU."""
    rng = bm.random.RandomState(seed=seed)
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
    kwargs = dict(shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
    return _compare_impls(jitconn_matvec_normal_taichi_grad, jitconn_matvec_normal_grad,
                          (events, w_mu, w_sigma, conn_prob, seed), kwargs, 'cpu')


def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel):
    """Benchmark the homogeneous-weight jitconn matvec gradient on GPU."""
    rng = bm.random.RandomState(seed=seed)
    vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
    kwargs = dict(shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)
    return _compare_impls(jitconn_matvec_homo_taichi_grad, jitconn_matvec_homo_grad,
                          (vector, homo_data, conn_prob, seed), kwargs, 'gpu')


def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel):
    """Benchmark the uniform-weight jitconn matvec gradient on GPU."""
    rng = bm.random.RandomState(seed=seed)
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
    kwargs = dict(shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
    return _compare_impls(jitconn_matvec_uniform_taichi_grad, jitconn_matvec_uniform_grad,
                          (events, w_low, w_high, conn_prob, seed), kwargs, 'gpu')


def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel):
    """Benchmark the normal-weight jitconn matvec gradient on GPU."""
    rng = bm.random.RandomState(seed=seed)
    events = bm.as_jax(rng.random(shape[0] if transpose else shape[1]))
    kwargs = dict(shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
    return _compare_impls(jitconn_matvec_normal_taichi_grad, jitconn_matvec_normal_grad,
                          (events, w_mu, w_sigma, conn_prob, seed), kwargs, 'gpu')


def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel):
    """Dispatch a CPU benchmark by weight type ('homo' / 'uniform' / 'normal')."""
    print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose,
          ' outdim_parallel: ', outdim_parallel)
    if _type == 'homo':
        return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel)
    elif _type == 'uniform':
        return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel)
    elif _type == 'normal':
        return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel)
    else:
        raise ValueError(f'Unknown weight type: {_type}')


def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel):
    """Dispatch a GPU benchmark by weight type ('homo' / 'uniform' / 'normal')."""
    print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose,
          ' outdim_parallel: ', outdim_parallel)
    if _type == 'homo':
        return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel)
    elif _type == 'uniform':
        return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel)
    elif _type == 'normal':
        return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel)
    else:
        raise ValueError(f'Unknown weight type: {_type}')


PATH = os.path.dirname(os.path.abspath(__file__))

# Result table: one row per (shape1, shape2, type, transpose, outdim_parallel) combination.
df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel',
                           'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)',
                           'taichi aot time4(ms)', 'taichi aot time5(ms)',
                           'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)',
                           'brainpy time4(ms)', 'brainpy time5(ms)',
                           'speedup'])

### RECTANGULAR MATRIX
if bm.get_platform() == 'cpu':
    for shape1 in shape:
        for shape2 in shape:
            for _type in types:
                for _outdim_parallel in outdim_parallel:
                    for _transpose in transpose:
                        row = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel)
                        df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, *row]
                        # save after every run so partial results survive an interrupted sweep
                        df.to_csv(f'{PATH}/jitconn_matvec_grad_cpu.csv', index=False)

if bm.get_platform() == 'gpu':
    for shape1 in shape:
        for shape2 in shape:
            for _type in types:
                for _outdim_parallel in outdim_parallel:
                    for _transpose in transpose:
                        row = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel)
                        df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, *row]
                        # save after every run so partial results survive an interrupted sweep
                        df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False)
for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += value * vector[col_indices[j]] - out[row_i] = r + r += vector[col_indices[j]] + out[row_i] = r * value @ti.kernel @@ -115,9 +115,9 @@ def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), j = row_ptr[row_i] + index end_index = row_ptr[row_i + 1] while j < end_index: - r += value * vector[col_indices[j]] + r += vector[col_indices[j]] j += 32 - out[row_i] += r # TODO: warp-level primitive + out[row_i] += value * r @ti.kernel @@ -285,4 +285,4 @@ def _define_op(cpu_kernel, gpu_kernel): # no transpose heter _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py index 8ff6e1481..3ae91a036 100644 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py @@ -42,7 +42,7 @@ def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. 
@@ -134,7 +134,7 @@ def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): print('brainpylib_cpu_3: ', brainpy_time3, 'ms') print('brainpylib_cpu_4: ', brainpy_time4, 'ms') print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) + assert(bm.allclose(result1[0], result2)) speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 @@ -144,20 +144,30 @@ def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. + heter_data = bm.ones(indices.shape) * weight + + # dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + + # if transpose: + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense, dtype=float) + # else: + # groundtruth = bm.as_jax(dense, dtype=float) @ bm.as_jax(vector, dtype=float) + + # groundtruth = groundtruth * weight + + if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) # time.sleep(2) + + # assert(bm.allclose(result1[0], groundtruth)) time0 = time.time() result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) @@ -183,16 +193,7 @@ def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): time9 = time.time() result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, 
indptr, vector, shape=shape, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - time12 = time.time() result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() @@ -238,8 +239,12 @@ def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): print('brainpylib_gpu_3: ', brainpy_time3, 'ms') print('brainpylib_gpu_4: ', brainpy_time4, 'ms') print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + + # print('------------------------------------------------------') + # print(result1[0]) + # print('------------------------------------------------------') + # print(result2) - # assert(jnp.allclose(result1[0], result2)) speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 @@ -554,4 +559,4 @@ def test_sparse_csrmv_square_gpu(s, p, values_type, events_type, transpose): # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) # # 找到对应的行 # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py new file mode 100644 index 000000000..c267044a0 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py @@ -0,0 +1,340 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax 
+import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('gpu') + +s = [1000, + 5000, + 10000, + 15000, + 20000, + 25000, + 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 +] + +values_type = [ + 'homo', + 'heter' + ] +events_type = ['float'] +transpose = [ + True, + False + ] +method = 'cusparse' + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +def sum_op2(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): + return jax.value_and_grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_grad(weight, indices, indptr, vector, shape, transpose): + return jax.value_and_grad(sum_op(bm.sparse.csrmv), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + +def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05, allow_multi_conn=True)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 + weight = 1. 
+ + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + # time.sleep(2) + + time0 = time.time() + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time7 = time.time() + + time8 = time.time() + tuple0, result1 = csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time9 = time.time() + + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + + time12 = time.time() + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, 
transpose=transpose) + time19 = time.time() + + time20 = time.time() + tuple1, result2 = csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + + # assert(bm.allclose(tuple0, result1, tuple0, result2)) + print('1:',tuple0) + print('2:',tuple1) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05, allow_multi_conn=True)(*shape).require('pre2post') + vector = rng.random(shape[0] if 
transpose else shape[1]) < 0.1 + weight = 1. + + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time9 = time.time() + + tuple1, result2 = jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + + time12 = time.time() + tuple1, result2 = jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + tuple1, result2 = jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + tuple1, result2 = 
jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + tuple1, result2 = jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + tuple1, result2 = jax.block_until_ready(csrmv_grad( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_gpu_1: ', brainpy_time1, 'ms') + print('brainpylib_gpu_2: ', brainpy_time2, 'ms') + print('brainpylib_gpu_3: ', brainpy_time3, 'ms') + print('brainpylib_gpu_4: ', brainpy_time4, 'ms') + print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + + # print(tuple0, result1 - tuple0, result2) + print('1:',tuple0) + print('2:',tuple1) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, 
brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'speedup']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, 
taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', index=False) + +# if (bm.get_platform() == 'gpu'): +# for _s in s: +# for _p in p: +# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) +# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] +# df.to_csv('event_ell_gpu.csv', index=False) + + # df = pd.read_csv('event_ell_gpu.csv') + # for _s in s: + # for _p in p: + # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) + # # 找到对应的行 + # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time + # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file From 6c3eddd48c235cf17c35bf33bc8383ddab295bfb Mon Sep 17 00:00:00 2001 From: routhleck <1310722434@qq.com> Date: Fri, 12 Jan 2024 09:31:18 +0800 Subject: [PATCH 5/7] New benchmark method --- .../event_csrmv_taichi_VS_event_csrmv.py | 559 +++---------- .../event_csrmv_taichi_VS_event_csrmv_grad.py | 353 +++----- ...t_matvec_taichi_VS_jitconn_event_matvec.py | 791 ++++++++---------- ...vec_taichi_VS_jitconn_event_matvec_grad.py | 783 +++++++---------- ...jitconn_matvec_taichi_VS_jitconn_matvec.py | 788 ++++++++--------- ...nn_matvec_taichi_VS_jitconn_matvec_grad.py | 106 +-- .../sparse/tests/csrmv_taichi_VS_csrmv.py | 558 +++--------- .../tests/csrmv_taichi_VS_csrmv_grad.py | 317 +++---- 8 files changed, 1520 insertions(+), 2735 deletions(-) diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py index 658c73573..3ac1e0ee2 100644 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py @@ -42,11 +42,29 @@ False ] +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) 
-def test_event_csrmv_cpu(shape, values_type, events_type, transpose): +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv_taichi(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) + return r + +def test_event_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. @@ -57,477 +75,146 @@ def test_event_csrmv_cpu(shape, values_type, events_type, transpose): heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, 
transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, 
shape=shape, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - 
brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_event_csrmv_gpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. + time10 = time.time() + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # 
print(groundtruth - result2) - - print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - print('shape: ', shape, 'values_type: ', 
values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_event_csrmv_square_cpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = bm.random.rand(s) < 0.5 - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = 
jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, 
csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], 
result2)) + time24 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time26 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = bm.random.rand(s) < 0.5 - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # 
time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # print('--------------------result1[0]------------------') - # print(result1[0]) - # print('--------------------result2------------------') - # print(result2) - # print('--------------------gt------------------') - # print(groundtruth) - # print('--------------------gt - result1[0]------------------') - # print(groundtruth - result1[0]) - # print('--------------------gt - result2------------------') - # print(groundtruth - result2) + time28 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 
= time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time21 = time.time() + time30 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - print('s: ', s, 'p: ', p, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + 
brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + # assert(jnp.allclose(result1[0], result2)) return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, 
brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### SQUARE MATRIX - -# if (bm.get_platform() == 'cpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'cpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_cpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'gpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, 
taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_gpu.csv', index=False) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -537,11 +224,15 @@ def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2,'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -551,25 +242,13 @@ def test_event_csrmv_square_gpu(s, p, 
values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_gpu.csv', index=False) - - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git 
a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py index f4feb7f9e..98793e600 100644 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py @@ -42,6 +42,10 @@ False ] +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) def sum_op(op): @@ -61,20 +65,24 @@ def func(*args, **kwargs): @partial(jax.jit, static_argnums=(4, 5)) def event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): - return jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r @partial(jax.jit, static_argnums=(4, 5)) def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose): - return jax.grad(sum_op(bm.event.csrmv), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.event.csrmv), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r -def test_event_csrmv_cpu(shape, values_type, events_type, transpose): +def test_event_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. 
@@ -85,262 +93,145 @@ def test_event_csrmv_cpu(shape, values_type, events_type, transpose): heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - - # assert(bm.allclose(result1, result2)) - - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = 
time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - + + time10 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, 
vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', 
taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1, result2)) + time22 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() -def test_event_csrmv_gpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. 
+ time28 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(event_csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - - time12 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, 
indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(event_csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + 
taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # if not bm.allclose(result1, result2): - # print('False') - # diff = result1 - result2 - # print(diff[:1000]) - # print(diff[-1000:]) - - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - 
brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### SQUARE MATRIX - -# if (bm.get_platform() == 'cpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'cpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_cpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, 
brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'gpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_gpu.csv', index=False) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -350,11 +241,15 @@ def test_event_csrmv_gpu(shape, values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2,'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, 
brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -364,25 +259,13 @@ def test_event_csrmv_gpu(shape, values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False) - - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = 
test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py index 94c96dabb..21a246650 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') seed = 1234 @@ -49,9 +49,56 @@ w_mu = 0. w_sigma = 0.1 +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, 
transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: @@ -59,607 +106,432 @@ def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, 
shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = 
jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, 
brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, 
w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, 
w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = 
(time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event): + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event): rng = 
bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, 
transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert 
bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = 
time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', 
brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time4 = time.time() - result1 = 
jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, 
transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, 
w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + 
print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, 
outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = 
time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # 
time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = 
jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events 
= events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + 
brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, 
' outdim_parallel: ', outdim_parallel) if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) else: raise ValueError -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event) - else: - raise ValueError - PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX 
if (bm.get_platform() == 'cpu'): @@ -670,11 +542,15 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -685,24 +561,13 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, 
taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py index 3603778dd..ff4f01afc 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py +++ 
b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') # bm.disable_gpu_memory_preallocation() seed = 1234 @@ -59,615 +59,483 @@ def func(*args, **kwargs): return func +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + @partial(jax.jit, static_argnums=(4, 5, 6)) def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)( - vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r +=jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(4, 5, 6)) def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)( - vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)( - vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += 
jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)( - vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)( - vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)( + vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)( - vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) - -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)( + vector.astype(float), w_mu, 
w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense) + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) 
time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - + + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + 
time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = 
jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + 
taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - result2 = 
jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() + time28 = time.time() + 
result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - 
time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec_uniform(shape, 
transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) + + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, 
w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, 
w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, 
transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - print('start') + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() - # bm.clear_buffer_memory() - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, 
seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() - # bm.clear_buffer_memory() - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() - # bm.clear_buffer_memory() - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() - # bm.clear_buffer_memory() - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - # bm.clear_buffer_memory() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time9 = time.time() - - print('taichi finished') - - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel)) - - time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - - print('brainpylib finished') + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + 
taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - bm.clear_buffer_memory() return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + 
brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - print('start') - - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - 
# time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - print('taichi finished') - - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, 
conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() - 
print('brainpylib finished') - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - 
(taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - bm.clear_buffer_memory() - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel)) - time9 = time.time() + time20 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time12 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time13 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time14 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time15 = time.time() - # time.sleep(2) + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time16 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) - time21 = time.time() + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - 
brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - bm.clear_buffer_memory() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - + 
taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event) - else: - raise ValueError - - -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) else: raise ValueError @@ -675,9 +543,11 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event # init dataframe df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi 
aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -688,11 +558,15 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', 
index=False) if (bm.get_platform() == 'gpu'): @@ -703,24 +577,13 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # 
df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py index fef4d3aeb..14a19aefb 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py @@ -38,6 +38,7 @@ True, False, ] +bool_event = False conn_prob = 0.05 homo_data = 1. w_low = 0. @@ -45,609 +46,481 @@ w_mu = 0. w_sigma = 0.1 +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, 
transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, 
transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + 
result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, 
transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 
1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + time24 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, 
result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = 
jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], 
result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel): +def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, 
outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, 
w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, 
outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = 
jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # 
time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, 
transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, 
conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - 
print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, 
shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, 
transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # 
time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = 
jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = 
jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, 
shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, 
outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + 
taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, 
outdim_parallel) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel) - elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel) - else: - raise ValueError - + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel): print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel) elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel) elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel) else: raise ValueError PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 
'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -657,11 +530,15 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): for _outdim_parallel in outdim_parallel: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -671,24 +548,13 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): for _outdim_parallel in outdim_parallel: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - 
brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py index 
512bc1511..165c9b19b 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') seed = 1234 @@ -25,6 +25,7 @@ 37500, 50000 ] +bool_event = False types = [ 'homo', 'uniform', @@ -45,6 +46,10 @@ w_mu = 0. w_sigma = 0.1 +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) def sum_op(op): @@ -56,39 +61,57 @@ def func(*args, **kwargs): @partial(jax.jit, static_argnums=(4, 5, 6)) def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)( - vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(4, 5, 6)) def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)( - vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)( - vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, 
outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)( - vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)( - vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r @partial(jax.jit, static_argnums=(5, 6, 7)) def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): - return jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)( - vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel - ) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r def test_jitconn_matvec_homo_cpu(shape, transpose, 
outdim_parallel): rng = bm.random.RandomState(seed=seed) @@ -448,11 +471,11 @@ def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): print('taichi_aot_3: ', taichi_aot_time3, 'ms') print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') # assert(jnp.allclose(result1[0], result2)) speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ @@ -543,11 +566,11 @@ def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): print('taichi_aot_3: ', taichi_aot_time3, 'ms') print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') # assert(jnp.allclose(result1[0], result2)) speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ @@ -638,11 +661,11 @@ def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): print('taichi_aot_3: ', taichi_aot_time3, 'ms') print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', 
taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') # assert(jnp.allclose(result1[0], result2)) speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ @@ -711,18 +734,3 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py index 3ae91a036..1db246212 100644 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') s = [1000, 5000, 
10000, 15000, 20000, 25000, 30000] p = [0.1, 0.2, 0.3, 0.4, 0.5] @@ -38,525 +38,213 @@ ] method = 'cusparse' +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_taichi(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) + return r + +def test_sparse_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. 
+ + if events_type == 'float': + vector = vector.astype(bm.float32) if values_type == 'heter': heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = 
jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = 
jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - 
assert(bm.allclose(result1[0], result2)) + time22 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() -def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234)(*shape).require('pre2post') - - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. 
- - heter_data = bm.ones(indices.shape) * weight - - # dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=shape) + time28 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - # if transpose: - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense, dtype=float) - # else: - # groundtruth = bm.as_jax(dense, dtype=float) @ bm.as_jax(vector, dtype=float) + time30 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() - # groundtruth = groundtruth * weight + time32 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + time34 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() - if values_type == 'heter': - weight = heter_data - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) + time36 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() - # assert(bm.allclose(result1[0], groundtruth)) - - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = 
jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() + time38 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + 
taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # print('------------------------------------------------------') - # print(result1[0]) - # print('------------------------------------------------------') - # print(result2) - + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - 
brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_sparse_csrmv_square_cpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = rng.random(s) - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = 
jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', 
taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_sparse_csrmv_square_gpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = rng.random(s) - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, 
csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - # print('--------------------result1[0]------------------') - # print(result1[0]) - # print('--------------------result2------------------') - # print(result2) - # print('--------------------gt - result1[0]------------------') - # print(groundtruth - result1[0]) - # print('--------------------gt - result2------------------') - # print(groundtruth - result2) - - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - 
result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot 
time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### SQUARE MATRIX -# if (bm.get_platform() == 'cpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, 'cpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/csrmv_square_cpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, 'gpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/csrmv_square_gpu.csv', index=False) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) 
### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = 
test_sparse_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py index c267044a0..d902c9395 100644 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py 
+++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') s = [1000, 5000, @@ -44,6 +44,10 @@ ] method = 'cusparse' +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) def sum_op(op): @@ -63,278 +67,207 @@ def func(*args, **kwargs): @partial(jax.jit, static_argnums=(4, 5)) def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): - return jax.value_and_grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r @partial(jax.jit, static_argnums=(4, 5)) def csrmv_grad(weight, indices, indptr, vector, shape, transpose): - return jax.value_and_grad(sum_op(bm.sparse.csrmv), argnums=3)( + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - -def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): + return r + +def test_sparse_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, allow_multi_conn=True)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. 
+ + if events_type == 'float': + vector = vector.astype(bm.float32) if values_type == 'heter': heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - # time.sleep(2) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, 
transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - tuple0, result1 = csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - + + time10 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + time12 = time.time() - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, 
transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - tuple1, result2 = csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - - # assert(bm.allclose(tuple0, result1, tuple0, result2)) - print('1:',tuple0) - print('2:',tuple1) + time22 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, 
shape=shape, transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() -def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.05, allow_multi_conn=True)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. 
+ time28 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - tuple0, result1 = jax.block_until_ready(csrmv_taichi_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time9 = time.time() - - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - - time12 = time.time() - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, 
transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - tuple1, result2 = jax.block_until_ready(csrmv_grad( - weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 
1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # print(tuple0, result1 - tuple0, result2) - print('1:',tuple0) - print('2:',tuple1) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, 
taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, 
taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', 
index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) \ No newline at end of file From c3c9cbf182f7b5e8bca83e9a6991c7402e3ee04a Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 12 Jan 2024 16:31:06 +0800 Subject: [PATCH 6/7] fix bug --- .../_src/math/op_register/taichi_aot_based.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 5d053611f..0276e6e79 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -12,8 +12,8 @@ from jax.interpreters import xla from jax.lib import xla_client -from .utils import _shape_to_layout from brainpy._src.dependency_check import import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops +from .utils import _shape_to_layout ### UTILS ### @@ -37,38 +37,37 @@ def encode_md5(source: str) -> str: return md5.hexdigest() +# TODO +# not a very good way # get source with dependencies def get_source_with_dependencies(func, visited=None): if visited is None: visited = set() source = inspect.getsource(func) - if func in visited: return '' visited.add(func) - module = inspect.getmodule(func) - dependent_funcs = re.findall(r'(\w+)\(', source) for func_name in dependent_funcs: dependent_func = getattr(module, func_name, None) if callable(dependent_func): source += get_source_with_dependencies(dependent_func, 
visited) - return source + # check if Metal is supported def is_metal_supported(): - # first check if we are on macOS - if platform.system() != 'Darwin': - return False + # first check if we are on macOS + if platform.system() != 'Darwin': + return False + if platform.processor() != 'arm': + return False + return True - if platform.processor() != 'arm': - return False - return True ### VARIABLES ### home_path = get_home_dir() @@ -133,15 +132,16 @@ def _build_kernel( ti = import_taichi() # init arch - arch = None if device == 'cpu': if is_metal_device: arch = ti.arm64 - device == 'arm64' + device = 'arm64' else: arch = ti.x64 elif device == 'gpu': arch = ti.cuda + else: + raise ValueError(f'Unknown device: {device}') ti.init(arch=arch) @@ -343,7 +343,6 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs): in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs) ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) - fn = None if is_metal_device: fn = b'taichi_kernel_aot_call_cpu_arm64' else: From e40e9d438b9f8f8673d11dc7d225c8a229eb452f Mon Sep 17 00:00:00 2001 From: chaoming Date: Fri, 12 Jan 2024 16:34:53 +0800 Subject: [PATCH 7/7] update error info --- brainpy/_src/math/op_register/taichi_aot_based.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 0276e6e79..ab7b98011 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -117,7 +117,9 @@ def _array_to_field(dtype, shape) -> Any: elif dtype == np.float64: dtype = ti.float64 else: - raise TypeError + raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. 
' + f'If you think it is necessary, please open an issue at ' + f'https://github.com/brainpy/BrainPy/issues/new') return ti.field(dtype=dtype, shape=shape)