From 8737b939a7ee959af443a4d6c56386edfff986e8 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 3 Mar 2024 16:53:25 +0800 Subject: [PATCH] [math] Update event csrmm and csrmm --- brainpy/_src/math/event/__init__.py | 1 + .../event/{_csr_matmat.py => csr_matmat.py} | 268 ++++++++-------- .../tests/event_csr_matmat_VS_csr_matmat.py | 285 ++++++++++++++++++ .../_src/math/event/tests/test_event_csrmm.py | 4 +- brainpy/_src/math/sparse/__init__.py | 1 + .../math/sparse/{_csr_mm.py => csr_mm.py} | 126 ++++---- .../csr_matmat_VS_cusparse_csr_matmat.py | 258 ++++++++++++++++ brainpy/_src/math/sparse/tests/test_csrmm.py | 4 +- brainpy/math/event.py | 1 + 9 files changed, 736 insertions(+), 212 deletions(-) rename brainpy/_src/math/event/{_csr_matmat.py => csr_matmat.py} (71%) create mode 100644 brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py rename brainpy/_src/math/sparse/{_csr_mm.py => csr_mm.py} (78%) create mode 100644 brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 9ebad3e9..91b479b6 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,3 @@ from .csr_matvec import * +from .csr_matmat import * diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py similarity index 71% rename from brainpy/_src/math/event/_csr_matmat.py rename to brainpy/_src/math/event/csr_matmat.py index 024f1692..6936b495 100644 --- a/brainpy/_src/math/event/_csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -12,8 +12,8 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (XLACustomOp) -from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmm -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm +from brainpy._src.math.sparse.utils import csr_to_coo ti = import_taichi() @@ -125,17 +125,16 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -144,17 +143,16 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. 
+ for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -163,13 +161,12 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -178,13 +175,12 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -194,16 +190,15 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -213,16 +208,15 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -232,13 +226,12 @@ def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. 
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -248,13 +241,12 @@ def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -265,17 +257,16 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -284,17 +275,16 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -303,13 +293,12 @@ def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -318,13 +307,12 @@ def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -334,16 +322,15 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -353,16 +340,15 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -372,13 +358,12 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -388,13 +373,12 @@ def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. 
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): diff --git a/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py new file mode 100644 index 00000000..872c69e1 --- /dev/null +++ b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py @@ -0,0 +1,285 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'heter', + 'homo', +] +events_type = ['bool', + 'float', + ] +transpose = [ + # True, + False +] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +def test_sparse_csrmm(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = rng.random(matrix2_shape) + matrix = bm.as_jax(matrix) + weight = 1. 
+ + if events_type == 'float': + matrix = matrix.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time19 = time.time() + + result1 = result + + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time21 = time.time() + + result2 = result + + time22 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = 
jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time39 = time.time() + + csrmm_time1 = (time1 - time0) * 1000 + csrmm_time2 = (time3 - time2) * 1000 + csrmm_time3 = (time5 - time4) * 1000 + csrmm_time4 = (time7 - time6) * 1000 + csrmm_time5 = (time9 - time8) * 1000 + csrmm_time6 = (time11 - time10) * 1000 + csrmm_time7 = (time13 - time12) * 1000 + csrmm_time8 = (time15 - time14) * 1000 + csrmm_time9 = (time17 - time16) * 1000 + csrmm_time10 = (time19 - time18) * 1000 + event_csrmm_time1 = (time21 - time20) * 1000 + event_csrmm_time2 = (time23 - time22) * 1000 + event_csrmm_time3 = (time25 - time24) * 1000 + event_csrmm_time4 = (time27 - time26) * 1000 + event_csrmm_time5 = (time29 - time28) * 1000 + event_csrmm_time6 = (time31 - time30) * 1000 + event_csrmm_time7 = (time33 - time32) * 1000 + event_csrmm_time8 = (time35 - time34) * 1000 + event_csrmm_time9 = (time37 - time36) * 1000 + event_csrmm_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('csrmm_1: ', csrmm_time1, 'ms') + print('csrmm_3: ', csrmm_time3, 'ms') + print('csrmm_5: ', csrmm_time5, 'ms') + print('csrmm_7: ', csrmm_time7, 'ms') + print('csrmm_9: ', csrmm_time9, 'ms') + print('event_csrmm_1: ', event_csrmm_time1, 'ms') + print('event_csrmm_3: ', event_csrmm_time3, 'ms') + print('event_csrmm_5: ', event_csrmm_time5, 'ms') + print('event_csrmm_7: ', event_csrmm_time7, 'ms') + print('event_csrmm_9: ', event_csrmm_time9, 'ms') + + r = bm.allclose(result1, result2) + if not r: + print('result1: ', result1) + print('result2: ', result2) + + return csrmm_time1, csrmm_time2, csrmm_time3, csrmm_time4, csrmm_time5, \ + csrmm_time6, csrmm_time7, csrmm_time8, csrmm_time9, csrmm_time10, \ + event_csrmm_time1, event_csrmm_time2, event_csrmm_time3, event_csrmm_time4, event_csrmm_time5, \ + event_csrmm_time6, event_csrmm_time7, event_csrmm_time8, event_csrmm_time9, event_csrmm_time10 + + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame( + columns=['shape', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'csrmm time1(ms)', 'csrmm time2(ms)', 'csrmm time3(ms)', 'csrmm time4(ms)', + 'csrmm time5(ms)', + 'csrmm time6(ms)', 'csrmm time7(ms)', 'csrmm time8(ms)', 'csrmm time9(ms)', + 'csrmm time10(ms)', + 'event_csrmm time1(ms)', 'event_csrmm time2(ms)', 'event_csrmm time3(ms)', 'event_csrmm 
time4(ms)', + 'event_csrmm time5(ms)', + 'event_csrmm time6(ms)', 'event_csrmm time7(ms)', 'event_csrmm time8(ms)', 'event_csrmm time9(ms)', + 'event_csrmm time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py index c570d153..12a35ef3 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmm.py +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -140,7 +140,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.event.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -152,7 +152,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), + bm.asarray([homo_data]), indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) 
self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 14256cbc..13c9e1e2 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,6 +1,7 @@ # from ._coo_mv import * # from ._bsr_mv import * from .csr_mv import * +from .csr_mm import * from .utils import * from .bsr_mm import * from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py similarity index 78% rename from brainpy/_src/math/sparse/_csr_mm.py rename to brainpy/_src/math/sparse/csr_mm.py index b5e21446..dba93a79 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -8,13 +8,14 @@ import brainpy.math as bm from jax import numpy as jnp from jax.interpreters import ad +from jax.core import Tracer from jax.experimental.sparse import csr from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo ti = import_taichi() @@ -32,7 +33,8 @@ def csrmm( shape: Tuple[int, int], transpose: bool = False, ): - """Product of CSR sparse matrix and a dense matrix. + """ + Product of CSR sparse matrix and a dense matrix. Args: data : array of shape ``(nse,)``. @@ -46,7 +48,7 @@ def csrmm( Returns: C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. + representing the matrix-matrix product. """ return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] @@ -99,7 +101,7 @@ def raw_csrmm_taichi( return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: if transpose: - prim = _csr_matmat_transpose_heter_p + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: prim = _csr_matmat_heter_p else: @@ -124,16 +126,15 @@ def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -142,12 +143,11 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -157,15 +157,14 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -175,12 +174,11 @@ def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -191,16 +189,15 @@ def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -209,12 +206,11 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -224,15 +220,14 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. 
+ for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -242,12 +237,11 @@ def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py new file mode 100644 index 00000000..79c8bef0 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -0,0 +1,258 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'heter' + ] +events_type = ['float'] +transpose = [ + True, + # False + ] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None) + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='cusparse') + return r + +def test_sparse_csrmm(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = rng.random(matrix2_shape) + matrix = bm.as_jax(matrix) + weight = 1. 
+ + + + if events_type == 'float': + matrix = matrix.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time19 = time.time() + + result1 = result + + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time21 = time.time() + + result2 = result + + time22 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, 
transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + print(bm.allclose(result1, result2)) + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 
'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index c05b45cd..e4346c84 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -133,7 +133,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.sparse.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -145,7 +145,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), + bm.asarray([homo_data]), indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r3, 
r4))

diff --git a/brainpy/math/event.py b/brainpy/math/event.py
index 02e98b8f..3b4b5ed1 100644
--- a/brainpy/math/event.py
+++ b/brainpy/math/event.py
@@ -1,3 +1,4 @@
 from brainpy._src.math.event import (
   csrmv as csrmv,
+  csrmm as csrmm,
 )
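
For reference, a minimal usage sketch of the two operators this patch exports (brainpy.math.sparse.csrmm and brainpy.math.event.csrmm), mirroring the benchmark scripts above. The sizes, the FixedProb connectivity, and the homogeneous length-1 weight are illustrative assumptions, not part of the patch itself.

import brainpy as bp
import brainpy.math as bm
import jax

# sparse matrix is (n_pre, n_post); the dense right-hand matrix is (n_post, n_col)
n_pre, n_post, n_col = 1000, 1000, 100
rng = bm.random.RandomState(1234)
indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(n_pre, n_post).require('pre2post')
weights = bm.asarray([1.])                      # homogeneous weight, wrapped as a length-1 array
dense = bm.as_jax(rng.random((n_post, n_col)))  # dense right-hand matrix
events = (dense < 0.1).astype(bm.float32)       # sparse "event" matrix of the same shape

out1 = bm.sparse.csrmm(weights, indices, indptr, dense, shape=(n_pre, n_post))
out2 = bm.event.csrmm(weights, indices, indptr, events, shape=(n_pre, n_post))
assert out1.shape == out2.shape == (n_pre, n_col)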
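
As reflected in the test updates above (homo_data replaced by bm.asarray([homo_data])), gradients with respect to a homogeneous weight are taken by passing it as a length-1 array rather than a Python scalar. A sketch under that assumption, reusing the names defined in the snippet above:

def loss(w):
  # sum-reduce the event-driven CSR @ dense product so jax.grad gets a scalar
  return bm.event.csrmm(w, indices, indptr, events, shape=(n_pre, n_post)).sum()

grad_w = jax.grad(loss)(bm.asarray([1.]))  # gradient w.r.t. the homogeneous weight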