diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 33aa803d..dfea2a6b 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -92,6 +92,7 @@ def raw_csrmm_taichi( if indices.shape[0] == 0: return [jnp.zeros(result_shape, dtype=data.dtype), ] + # homo -> taichi, # heter -> cusparse if data.shape[0] != 1: @@ -118,11 +119,11 @@ def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += values[row_i] * matrix[row_i, j] @ti.kernel @@ -139,17 +140,32 @@ def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), @ti.kernel -def _csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += value * matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += value * matrix[row_i, j] @ti.kernel @@ -213,8 +229,8 @@ def _define_op(cpu_kernel, gpu_kernel): gpu_kernel=_csr_matmat_heter) # transpose homo -_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo, - gpu_kernel=_csr_matmat_transpose_homo) +_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) # no transpose homo _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py index 79c8bef0..f11275b1 100644 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -14,163 +14,180 @@ bm.set_platform('cpu') -s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] +SPARSITY = 0.05 size = [ - (100, 100, 100), - (100, 1000, 100), - (1000, 1000, 100), - (1000, 1000, 1000), - (100, 10000, 100), - (10000, 100, 1000), - (1000, 100, 10000), - (10000, 10000, 1000), - (10000, 1000, 10000), - (10000, 10000, 10000), - (20000, 20000, 20000), + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), ] values_type = [ - 'heter' - ] + 'homo', + # 'heter' +] events_type = ['float'] transpose = [ - True, - # False - ] + True, + False +] -ITERATION = 100 +ITERATION = 10 if bm.get_platform() == 'cpu': - ITERATION = 10 + ITERATION = 3 print(bm.get_platform()) + @partial(jax.jit, static_argnums=(4, 5)) def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None) return r - + + @partial(jax.jit, static_argnums=(4, 5)) def csrmm(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): - r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='cusparse') + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='jaxlib') return r + def test_sparse_csrmm(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) matrix2_shape = (shape[1], shape[2]) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') matrix = rng.random(matrix2_shape) matrix = bm.as_jax(matrix) weight = 1. - - + + heter_data = bm.ones(indices.shape) * weight if events_type == 'float': matrix = matrix.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data + # if values_type == 'heter': + # weight = heter_data - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time0 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time1 = time.time() time2 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time3 = time.time() time4 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time5 = time.time() time6 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time9 = time.time() - + time10 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time11 = time.time() - + time12 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time13 = time.time() - + time14 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time15 = time.time() - + time16 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time17 = time.time() - + time18 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time19 = time.time() - + result1 = result - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time20 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time21 = time.time() - + result2 = result - + time22 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time23 = time.time() time24 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time25 = time.time() time26 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time27 = time.time() time28 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time29 = time.time() - + time30 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time31 = time.time() - + time32 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time33 = time.time() - + time34 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time35 = time.time() - + time36 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time37 = time.time() - + time38 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 @@ -206,53 +223,67 @@ def test_sparse_csrmm(shape, values_type, events_type, transpose): print('brainpylib_9: ', brainpy_time9, 'ms') print(bm.allclose(result1, result2)) - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) +df = pd.DataFrame( + columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', + 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', + 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): for shape in size: for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, + _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): for shape in size: for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False)