From b06e817f99903731f3f9a57942a7184954bf32fe Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 30 Jan 2024 17:39:47 +0800 Subject: [PATCH 01/23] [math] Add sparse `matrix@matrix` operators --- brainpy/_src/math/sparse/__init__.py | 1 + brainpy/_src/math/sparse/_csr_mm.py | 265 +++++++++++++++++++ brainpy/_src/math/sparse/tests/test_csrmm.py | 158 +++++++++++ brainpy/math/sparse.py | 1 + 4 files changed, 425 insertions(+) create mode 100644 brainpy/_src/math/sparse/_csr_mm.py create mode 100644 brainpy/_src/math/sparse/tests/test_csrmm.py diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80b..a5f383204 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,6 +1,7 @@ from ._coo_mv import * from ._csr_mv import * +from ._csr_mm import * from ._utils import * from ._bsr_mv import * from ._bsr_mm import * diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py new file mode 100644 index 000000000..82b0506d4 --- /dev/null +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + + +from functools import partial +from typing import Union, Tuple + +import jax +import numba +import numpy as np +from jax import core, dtypes +from jax import numpy as jnp +from jax.interpreters import ad, mlir, xla +from jax.lib import xla_client +from jaxlib import gpu_sparse + +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.ndarray import Array +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching, + XLACustomOp) +from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy.errors import GPUOperatorNotFound + +ti = import_taichi() + +__all__ = [ + 'csrmm', +] + + +def csrmm( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int, int], + transpose: bool = False, +): + print('data: ', data) + data = jnp.atleast_1d(as_jax(data)) + print('data: ', data) + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + + if matrix.dtype == jnp.bool_: + matrix = as_jax(matrix, dtype=data.dtype) + + if data.dtype != matrix.dtype: + raise TypeError('The types of data and vector should be the same. 
' + f'But we got {data.dtype} != {matrix.dtype}.') + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert matrix.ndim == 2 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros((shape[2], shape[0]) if transpose else (shape[2], shape[1]), dtype=data.dtype) + + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + + +def raw_csrmm_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int, int], + transpose: bool = False, +): + out_shape = (shape[2], shape[0]) if transpose else (shape[2], shape[0]) + if transpose: + if data.shape[0] == 1: + prim = _csr_matmat_transpose_homo_p + else: + prim = _csr_matmat_transpose_heter_p + else: + if data.shape[0] == 1: + prim = _csr_matmat_homo_p + else: + prim = _csr_matmat_heter_p + return prim(data, + indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(out_shape, dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +# CPU kernels + +@ti.kernel +def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + for k in range(matrix.shape[0]): + out[k, col_indices[j]] += value * matrix[k, row_i] + + +@ti.kernel +def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + for k in range(matrix.shape[0]): + out[k, col_indices[j]] += values[j] * matrix[k, row_i] + + +@ti.kernel +def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[j], col_k] + out[row_i, col_k] = r * value + + +@ti.kernel +def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. 
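+ # dot product of the stored nonzeros of sparse row `row_i` with dense column `col_k`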
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r + + +# GPU kernels + +@ti.kernel +def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + for k in range(matrix.shape[0]): + out[k, col_indices[j]] += value * matrix[k, row_i] + + +@ti.kernel +def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + for k in range(matrix.shape[0]): + out[k, col_indices[j]] += values[j] * matrix[k, row_i] + + +@ti.kernel +def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[j], col_k] + out[row_i, col_k] = r * value + + +@ti.kernel +def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. 
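+ # same accumulation as the CPU kernel; Taichi parallelizes the outermost loop over rows on GPU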
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r + + +def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + + +def _csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return raw_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + + +def _csr_matmat_transpose( + ct, data, indices, indptr, matrix, *, outs, transpose, shape, +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(matrix): + ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.inner(ct[0], ct_data) + else: + row, col = csr_to_coo(indices, indptr) + ct_data = jnp.zeros_like(data) + for i, j in zip(row, col): + if transpose: + ct_data[i] += jnp.sum(ct[0][:, j] * matrix[:, i]) + else: + ct_data[i] += jnp.sum(ct[0][j, :] * matrix[i, :]) + return ct_data, indices, indptr, matrix + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_csr_matmat_jvp_values, None, None, _csr_matmat_jvp_matrix) + prim.def_transpose_rule(_csr_matmat_transpose) + return prim + + +# transpose homo +_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + +# no transpose homo +_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, + gpu_kernel=_csr_matmat_homo_gpu) + +# transpose heter +_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, + gpu_kernel=_csr_matmat_transpose_heter_gpu) + +# no transpose heter +_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, + gpu_kernel=_csr_matmat_heter_gpu) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py new file mode 100644 index 000000000..612e5f6c3 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +# bm.set_platform('gpu') + +seed = 1234 + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +class Test_csrmm(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmm, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 0., 1.] 
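+ # negative, zero, and positive homogeneous weights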
+ ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape).value * homo_data + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (matrix @ dense) if transpose else (dense @ matrix) + r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, shape=shape, transpose=transpose) + c = bm.allclose(r1, r2) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 0., 1.] + ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + heter_data = bm.ones((10, indices.shape[0])).value * homo_data + homo_data = bm.ones(10).value * homo_data + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data) + + f1 = lambda a: (matrix @ a) if transpose else (a @ matrix) + f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=shape, transpose=transpose) + r1 = jax.vmap(f1)(dense) + r2 = jax.vmap(f2)(homo_data) + + self.assertTrue(bm.allclose(r1, r2)) + + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 0., 1.] 
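+ # same parameter sweep as test_homo, here checked through jax.grad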
+ ) + def test_homo_grad(self, transpose, shape, homo_data): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: ((matrix @ (dense * a)).sum() + if transpose else + ((dense * a) @ matrix).sum()), + argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(bm.sparse.csrmm))( + homo_data, indices, indptr, matrix, shape=shape, transpose=transpose + ) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad matrix + dense_data = dense * homo_data + dense_f2 = jax.grad(lambda m: ((m @ dense_data).sum() + if transpose else + (dense_data @ m).sum())) + r3 = dense_f2(matrix) + r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + homo_data, indices, indptr, matrix.astype(float), shape=shape, transpose=transpose + ) + + self.assertTrue(bm.allclose(r3, r4)) + + # grad both + dense_f3 = jax.grad(lambda a, m: ((m @ (dense * a)).sum() + if transpose else + ((dense * a) @ m).sum()), + argnums=(0, 1)) + r5 = dense_f3(homo_data, matrix) + r6 = jax.grad(sum_op(bm.sparse.csrmm), argnums=(0, 3))( + homo_data, indices, indptr, matrix.astype(float), shape=shape, transpose=transpose + ) + + self.assertTrue(bm.allclose(r5[0], r6[0])) + self.assertTrue(bm.allclose(r5[1], r6[1])) + + bm.clear_buffer_memory() diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9c..3a78c5252 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,5 +1,6 @@ from brainpy._src.math.sparse import ( csrmv, + csrmm, coomv, seg_matmul, From 90516c1cd709168a1e53412f83b6122619242df2 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 1 Feb 2024 22:41:01 +0800 Subject: [PATCH 02/23] [math] Implement csr matrix @ matrix operator --- brainpy/_src/math/sparse/_csr_mm.py | 167 ++++++------------- brainpy/_src/math/sparse/tests/test_csrmm.py | 81 ++++----- 2 files changed, 91 insertions(+), 157 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 82b0506d4..f104a4a4e 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -35,12 +35,39 @@ def csrmm( indptr: Union[jnp.ndarray, Array], matrix: Union[jnp.ndarray, Array], *, - shape: Tuple[int, int, int], + shape: Tuple[int, int], transpose: bool = False, ): - print('data: ', data) - data = jnp.atleast_1d(as_jax(data)) - print('data: ', data) + """Product of CSR sparse matrix and a dense matrix. + + Args: + data : array of shape ``(nse,)``. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + dtype ``data.dtype`` + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix + before computing. + + Returns: + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + representing the matrix-matrix product product. 
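+
+  Examples:
+    A minimal usage sketch (values chosen here for illustration, not part
+    of the original patch): a 2x3 CSR matrix ``[[1., 0., 2.], [0., 3., 0.]]``
+    multiplied by a dense 3x4 matrix of ones.
+
+    >>> import brainpy.math as bm
+    >>> from jax import numpy as jnp
+    >>> data = jnp.array([1., 2., 3.])
+    >>> indices = jnp.array([0, 2, 1])
+    >>> indptr = jnp.array([0, 2, 3])
+    >>> B = jnp.ones((3, 4))
+    >>> bm.sparse.csrmm(data, indices, indptr, B, shape=(2, 3)).shape
+    (2, 4)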
+ """ + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + + +def raw_csrmm_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + assert len(shape) == 2 + indices = as_jax(indices) indptr = as_jax(indptr) matrix = as_jax(matrix) @@ -53,87 +80,51 @@ def csrmm( f'But we got {data.dtype} != {matrix.dtype}.') assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 + assert data.shape == indices.shape + assert indptr.shape[0] == shape[0] + 1 if not jnp.issubdtype(indices.dtype, jnp.integer): raise ValueError('indices should be a 1D vector with integer type.') if not jnp.issubdtype(indptr.dtype, jnp.integer): raise ValueError('indptr should be a 1D vector with integer type.') - # if the shape of indices is (0,), then we return a zero vector + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix if indices.shape[0] == 0: - return jnp.zeros((shape[2], shape[0]) if transpose else (shape[2], shape[1]), dtype=data.dtype) - - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + return [jnp.zeros(result_shape, dtype=data.dtype),] - -def raw_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int, int], - transpose: bool = False, -): - out_shape = (shape[2], shape[0]) if transpose else (shape[2], shape[0]) - if transpose: - if data.shape[0] == 1: - prim = _csr_matmat_transpose_homo_p - else: - prim = _csr_matmat_transpose_heter_p + assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if transpose: + prim = _csr_matmat_transpose_heter_p else: - if data.shape[0] == 1: - prim = _csr_matmat_homo_p - else: - prim = _csr_matmat_heter_p + prim = _csr_matmat_heter_p return prim(data, indices, indptr, matrix, - outs=[jax.ShapeDtypeStruct(out_shape, dtype=data.dtype)], + outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], transpose=transpose, shape=shape) # CPU kernels -@ti.kernel -def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - for k in range(matrix.shape[0]): - out[k, col_indices[j]] += value * matrix[k, row_i] - - @ti.kernel def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), col_indices: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - for k in range(matrix.shape[0]): - out[k, col_indices[j]] += values[j] * matrix[k, row_i] - - -@ti.kernel -def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for col_i in range(out.shape[1]): + for row_k in 
range(out.shape[0]): r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[j], col_k] - out[row_i, col_k] = r * value + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -152,19 +143,6 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), # GPU kernels -@ti.kernel -def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - for k in range(matrix.shape[0]): - out[k, col_indices[j]] += value * matrix[k, row_i] - - @ti.kernel def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), col_indices: ti.types.ndarray(ndim=1), @@ -177,21 +155,6 @@ def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), out[k, col_indices[j]] += values[j] * matrix[k, row_i] -@ti.kernel -def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[j], col_k] - out[row_i, col_k] = r * value - - @ti.kernel def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), col_indices: ti.types.ndarray(ndim=1), @@ -221,24 +184,12 @@ def _csr_matmat_transpose( raise ValueError("Cannot transpose with respect to sparse indices.") if ad.is_undefined_primal(matrix): ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + return data, indices, indptr, ct_matrix else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = jnp.zeros_like(data) - for i, j in zip(row, col): - if transpose: - ct_data[i] += jnp.sum(ct[0][:, j] * matrix[:, i]) - else: - ct_data[i] += jnp.sum(ct[0][j, :] * matrix[i, :]) - return ct_data, indices, indptr, matrix + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + return (ct[0][row] * matrix[col]).sum(1), indices, indptr, matrix def _define_op(cpu_kernel, gpu_kernel): @@ -248,14 +199,6 @@ def _define_op(cpu_kernel, gpu_kernel): return prim -# transpose homo -_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) - -# no transpose homo -_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, - gpu_kernel=_csr_matmat_homo_gpu) - # transpose heter _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, gpu_kernel=_csr_matmat_transpose_heter_gpu) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index 612e5f6c3..cec9bde5b 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ 
b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -31,10 +31,9 @@ def __init__(self, *args, platform='cpu', **kwargs): @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 0., 1.] ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + def test_heter(self, transpose, shape): + print(f'test_homo: transpose: {transpose} shape = {shape}') conn = bp.conn.FixedProb(0.3) # csr matrix @@ -45,16 +44,17 @@ def test_homo(self, transpose, shape, homo_data): indptr = bm.as_jax(indptr) # matrix rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = rng.random((shape[1], shape[2])) matrix = bm.as_jax(matrix) - heter_data = bm.ones(indices.shape).value * homo_data + heter_data = bm.as_jax(rng.random(indices.shape)) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - r1 = (matrix @ dense) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, shape=shape, transpose=transpose) - c = bm.allclose(r1, r2) + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) if not c: print(r1 - r2) self.assertTrue(c) @@ -64,10 +64,9 @@ def test_homo(self, transpose, shape, homo_data): @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 0., 1.] ) - def test_homo_vmap(self, transpose, shape, homo_data): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + def test_heter_vmap(self, transpose, shape): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') conn = bp.conn.FixedProb(0.3) # csr matrix @@ -78,29 +77,28 @@ def test_homo_vmap(self, transpose, shape, homo_data): indptr = bm.as_jax(indptr) # matrix rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = rng.random((shape[1], shape[2])) matrix = bm.as_jax(matrix) - heter_data = bm.ones((10, indices.shape[0])).value * homo_data - homo_data = bm.ones(10).value * homo_data + heter_data = bm.as_jax(rng.random((10, indices.shape[0]))) dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data) - f1 = lambda a: (matrix @ a) if transpose else (a @ matrix) + f1 = lambda a: (a.T @ matrix) if transpose else (a @ matrix) f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=shape, transpose=transpose) + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) r1 = jax.vmap(f1)(dense) - r2 = jax.vmap(f2)(homo_data) + r2 = jax.vmap(f2)(heter_data) - self.assertTrue(bm.allclose(r1, r2)) + self.assertTrue(bm.allclose(r1, r2, equal_nan=True)) @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 0., 1.] 
) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}, homo_data = {homo_data}') + def test_heter_grad(self, transpose, shape): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) conn = bp.conn.FixedProb(0.3) # csr matrix @@ -109,50 +107,43 @@ def test_homo_grad(self, transpose, shape, homo_data): 'pre2post') indices = bm.as_jax(indices) indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + + heter_data = bm.as_jax(rng.random((indices.shape))) + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[2], shape[1]) if transpose else (shape[1], shape[2])) + matrix = rng.random((shape[1], shape[2])) matrix = bm.as_jax(matrix) # grad data - dense_f1 = jax.grad(lambda a: ((matrix @ (dense * a)).sum() + dense_f1 = jax.grad(lambda a: ((a.T @ matrix).sum() if transpose else - ((dense * a) @ matrix).sum()), + (a @ matrix).sum()), argnums=0) - r1 = dense_f1(homo_data) + r1 = dense_f1(dense) r2 = jax.grad(sum_op(bm.sparse.csrmm))( - homo_data, indices, indptr, matrix, shape=shape, transpose=transpose + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose ) + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + if transpose: + r1 = r1[cols, rows] + else: + r1 = r1[rows, cols] + print(r1 - r2) self.assertTrue(bm.allclose(r1, r2)) # grad matrix - dense_data = dense * homo_data - dense_f2 = jax.grad(lambda m: ((m @ dense_data).sum() + dense_f2 = jax.grad(lambda m: ((dense.T @ m).sum() if transpose else - (dense_data @ m).sum())) + (dense @ m).sum())) r3 = dense_f2(matrix) r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), shape=shape, transpose=transpose + heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose ) self.assertTrue(bm.allclose(r3, r4)) - # grad both - dense_f3 = jax.grad(lambda a, m: ((m @ (dense * a)).sum() - if transpose else - ((dense * a) @ m).sum()), - argnums=(0, 1)) - r5 = dense_f3(homo_data, matrix) - r6 = jax.grad(sum_op(bm.sparse.csrmm), argnums=(0, 3))( - homo_data, indices, indptr, matrix.astype(float), shape=shape, transpose=transpose - ) - - self.assertTrue(bm.allclose(r5[0], r6[0])) - self.assertTrue(bm.allclose(r5[1], r6[1])) - bm.clear_buffer_memory() From 787d6a83cb90fc6ba715288cbf696ca7a9fa101a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 1 Feb 2024 22:44:39 +0800 Subject: [PATCH 03/23] Update _csr_mm.py --- brainpy/_src/math/sparse/_csr_mm.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index f104a4a4e..1e474962d 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -149,10 +149,16 @@ def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - for k in range(matrix.shape[0]): - out[k, col_indices[j]] += values[j] * matrix[k, row_i] + for col_i in range(out.shape[1]): + for row_k 
in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel From 5de7807f0e4323eafa44070e6217a39f3ff8d13a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 3 Feb 2024 23:29:22 +0800 Subject: [PATCH 04/23] [math] Implement event csr matmat --- brainpy/_src/math/event/__init__.py | 1 + brainpy/_src/math/event/_csr_matmat.py | 461 ++++++++++++++++++ .../_src/math/event/tests/test_event_csrmm.py | 270 ++++++++++ brainpy/_src/math/sparse/_csr_mm.py | 82 +++- brainpy/math/event.py | 1 + 5 files changed, 811 insertions(+), 4 deletions(-) create mode 100644 brainpy/_src/math/event/_csr_matmat.py create mode 100644 brainpy/_src/math/event/tests/test_event_csrmm.py diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 631129558..7b76dbfeb 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,4 +1,5 @@ from ._info_collection import * from ._csr_matvec import * +from ._csr_matmat import * diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/_csr_matmat.py new file mode 100644 index 000000000..8fc48be92 --- /dev/null +++ b/brainpy/_src/math/event/_csr_matmat.py @@ -0,0 +1,461 @@ +# -*- coding: utf-8 -*- + + +from functools import partial +from typing import Union, Tuple + +import jax +import numba +import numpy as np +from jax import core, dtypes +from jax import numpy as jnp +from jax.interpreters import ad, mlir, xla +from jax.lib import xla_client +from jaxlib import gpu_sparse + +from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.ndarray import Array +from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, + register_general_batching, + XLACustomOp) +from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmv +from brainpy.errors import GPUOperatorNotFound + +ti = import_taichi() + +__all__ = [ + 'csrmm', +] + + +def csrmm( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + """Product of CSR sparse matrix and a dense event matrix. + + Args: + data : array of shape ``(nse,)``, float. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + dtype ``data.dtype`` + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix + before computing. + + Returns: + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + representing the matrix-matrix product product. 
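+
+  Examples:
+    A minimal usage sketch (values chosen here for illustration, not part
+    of the original patch): the same CSR layout as ``brainpy.math.sparse.csrmm``,
+    with ``B`` holding boolean events.
+
+    >>> import brainpy.math as bm
+    >>> from jax import numpy as jnp
+    >>> data = jnp.array([1., 2., 3.])
+    >>> indices = jnp.array([0, 2, 1])
+    >>> indptr = jnp.array([0, 2, 3])
+    >>> events = jnp.array([[True], [False], [True]])
+    >>> bm.event.csrmm(data, indices, indptr, events, shape=(2, 3)).shape
+    (2, 1)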
+ """ + return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + + +def raw_event_csrmm_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + assert len(shape) == 2 + + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert matrix.ndim == 2 + assert indptr.shape[0] == shape[0] + 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype),] + + assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if transpose: + if matrix.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csr_matmat_transpose_bool_homo_p + else: + prim = _event_csr_matmat_transpose_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csr_matmat_transpose_homo_p + else: + prim = _event_csr_matmat_transpose_heter_p + else: + if matrix.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csr_matmat_bool_homo_p + else: + prim = _event_csr_matmat_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csr_matmat_homo_p + else: + prim = _event_csr_matmat_heter_p + return prim(data, + indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +# CPU kernels + +@ti.kernel +def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r +@ti.kernel +def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. 
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] + out[row_i, col_k] = r + +@ti.kernel +def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] + out[row_i, col_k] = r + +@ti.kernel +def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + +@ti.kernel +def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + +# GPU kernels + +@ti.kernel +def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. 
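+ # linearly search row_j's stored columns for row_k; val stays 0. if absent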
+ for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] + out[row_i, col_k] = r + + +@ti.kernel +def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] + out[row_i, col_k] = r + +@ti.kernel +def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. 
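+ # the shared weight is factored out: sum the active event values, scale once below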
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + +@ti.kernel +def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + + +def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return raw_event_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + + +def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return raw_event_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + + +def _event_csr_matmat_transpose( + ct, data, indices, indptr, matrix, *, outs, transpose, shape, +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(matrix): + ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.sum(ct[0] * ct_data) + else: # heter + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + ct_data = (ct[0][row] * matrix[col]).sum(1) + return ct_data, indices, indptr, matrix + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) + prim.def_transpose_rule(_event_csr_matmat_transpose) + return prim + + +# transpose heter +_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter_cpu, + gpu_kernel=_event_csr_matmat_transpose_heter_gpu) + +# no transpose heter +_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter_cpu, + gpu_kernel=_event_csr_matmat_heter_gpu) + +# transpose homo +_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo_cpu, + gpu_kernel=_event_csr_matmat_transpose_homo_gpu) + +# no transpose homo +_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo_cpu, + gpu_kernel=_event_csr_matmat_homo_gpu) + +# bool transpose heter +_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter_cpu, + gpu_kernel=_event_csr_matmat_transpose_bool_heter_gpu) + +# bool no transpose heter +_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter_cpu, + gpu_kernel=_event_csr_matmat_bool_heter_gpu) + +# bool transpose homo +_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo_cpu, + gpu_kernel=_event_csr_matmat_transpose_bool_homo_gpu) + +# bool no transpose homo +_event_csr_matmat_bool_homo_p 
= _define_op(cpu_kernel=_event_csr_matmat_bool_homo_cpu, + gpu_kernel=_event_csr_matmat_bool_homo_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py new file mode 100644 index 000000000..e555a9214 --- /dev/null +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -0,0 +1,270 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +# bm.set_platform('gpu') + +seed = 1234 + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +class Test_csrmm(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmm, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape) * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
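+ # nonzero homogeneous weights for the vmap checks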
+ ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # vmap 'data' + f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + heter_data = bm.ones((10, indices.shape[0])) * homo_data + r1 = f1(heter_data) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + # vmap 'events' + heter_data = bm.ones(indices.shape) * homo_data + f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f4 = jax.vmap(partial(bm.event.csrmm, homo_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) + r3 = f3(matrix) + r4 = f4(matrix) + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo_grad(self, transpose, shape, homo_data): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() + if transpose else + ((dense * a) @ matrix).sum()), + argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(bm.event.csrmm))( + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() + if transpose else + ((dense * homo_data) @ m).sum()), + argnums=0) + r3 = dense_f2(matrix.astype(float)) + r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( + homo_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter(self, transpose, shape): + print(f'test_homo: transpose: {transpose} shape = {shape}') + 
conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.as_jax(rng.random(indices.shape)) + + r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = bm.event.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_vmap(self, transpose, shape): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # vmap 'data' + f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) + r1 = f1(vmap_data) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + # vmap 'events' + heter_data = bm.ones(indices.shape) + f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f4 = jax.vmap(partial(bm.event.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) + r3 = f3(matrix) + r4 = f4(matrix) + self.assertTrue(bm.allclose(r3, r4)) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_grad(self, transpose, shape): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + r1 = jax.grad(sum_op(bm.sparse.csrmm))( + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmm))( + heter_data, indices, indptr, matrix, shape=(shape[1], 
shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( + heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 1e474962d..2dc2c0c73 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -30,7 +30,7 @@ def csrmm( - data: Union[float, jnp.ndarray, Array], + data: Union[jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], indptr: Union[jnp.ndarray, Array], matrix: Union[jnp.ndarray, Array], @@ -80,7 +80,7 @@ def raw_csrmm_taichi( f'But we got {data.dtype} != {matrix.dtype}.') assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 - assert data.shape == indices.shape + assert data.shape == indices.shape or data.shape[0] == 1 assert indptr.shape[0] == shape[0] + 1 if not jnp.issubdtype(indices.dtype, jnp.integer): raise ValueError('indices should be a 1D vector with integer type.') @@ -95,9 +95,15 @@ def raw_csrmm_taichi( assert matrix.shape[0] == (shape[0] if transpose else shape[1]) if transpose: - prim = _csr_matmat_transpose_heter_p + if data.shape[0] == 1: + prim = _csr_matmat_transpose_homo_p + else: + prim = _csr_matmat_transpose_heter_p else: - prim = _csr_matmat_heter_p + if data.shape[0] == 1: + prim = _csr_matmat_homo_p + else: + prim = _csr_matmat_heter_p return prim(data, indices, indptr, @@ -140,6 +146,36 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), r += values[j] * matrix[col_indices[j], col_k] out[row_i, col_k] = r +@ti.kernel +def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -174,6 +210,36 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), r += values[j] * matrix[col_indices[j], col_k] out[row_i, col_k] = r +@ti.kernel +def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. 
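+ # computing (A^T @ B)[row_k, col_i]: add value * matrix[row_j, col_i] whenever sparse row row_j stores column row_k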
+ for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r + +@ti.kernel +def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i in range(row_ptr.shape[0] - 1): + for col_k in range(matrix.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) @@ -212,3 +278,11 @@ def _define_op(cpu_kernel, gpu_kernel): # no transpose heter _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, gpu_kernel=_csr_matmat_heter_gpu) + +# transpose homo +_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + +# no transpose homo +_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, + gpu_kernel=_csr_matmat_homo_gpu) \ No newline at end of file diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 0a17cae7c..0d5da9973 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,5 +1,6 @@ from brainpy._src.math.event import ( csrmv as csrmv, + csrmm as csrmm, info as info, ) From 8b403c5b25dfff1c8fd8233d4713b6dd055ee23b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 4 Feb 2024 14:07:57 +0800 Subject: [PATCH 05/23] [math] Support data is homo for csr matmat op --- brainpy/_src/math/sparse/_csr_mm.py | 26 +++- brainpy/_src/math/sparse/tests/test_csrmm.py | 120 +++++++++++++++++++ 2 files changed, 140 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 2dc2c0c73..075270790 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -30,7 +30,7 @@ def csrmm( - data: Union[jnp.ndarray, Array], + data: Union[float, jnp.ndarray, Array], indices: Union[jnp.ndarray, Array], indptr: Union[jnp.ndarray, Array], matrix: Union[jnp.ndarray, Array], @@ -72,6 +72,8 @@ def raw_csrmm_taichi( indptr = as_jax(indptr) matrix = as_jax(matrix) + data = jnp.atleast_1d(data) + if matrix.dtype == jnp.bool_: matrix = as_jax(matrix, dtype=data.dtype) @@ -80,7 +82,11 @@ def raw_csrmm_taichi( f'But we got {data.dtype} != {matrix.dtype}.') assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 - assert data.shape == indices.shape or data.shape[0] == 1 + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' 
+ f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') assert indptr.shape[0] == shape[0] + 1 if not jnp.issubdtype(indices.dtype, jnp.integer): raise ValueError('indices should be a 1D vector with integer type.') @@ -256,12 +262,20 @@ def _csr_matmat_transpose( raise ValueError("Cannot transpose with respect to sparse indices.") if ad.is_undefined_primal(matrix): ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, ct_matrix + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) else: - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - return (ct[0][row] * matrix[col]).sum(1), indices, indptr, matrix + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.sum(ct[0] * ct_data) + else: # heter + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + ct_data = (ct[0][row] * matrix[col]).sum(1) + return ct_data, indices, indptr, matrix def _define_op(cpu_kernel, gpu_kernel): diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index cec9bde5b..9cd7f0133 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -28,6 +28,126 @@ def __init__(self, *args, platform='cpu', **kwargs): print() bm.set_platform(platform) + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape) * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
+ ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones((10, indices.shape[0])) * homo_data + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data) + + # vmap 'data' + f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix)) + f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + + r1 = f1(dense) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo_grad(self, transpose, shape, homo_data): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() + if transpose else + ((dense * a) @ matrix).sum()), + argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(bm.sparse.csrmm))( + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() + if transpose else + ((dense * homo_data) @ m).sum()), + argnums=0) + r3 = dense_f2(matrix.astype(float)) + r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + homo_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], From eb09d0ea47be01a7a97282e9091fe02094e08cb2 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 4 Feb 2024 14:09:26 +0800 Subject: [PATCH 06/23] Update _csr_matmat.py --- brainpy/_src/math/event/_csr_matmat.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/_csr_matmat.py index 8fc48be92..fbd2d2524 100644 --- a/brainpy/_src/math/event/_csr_matmat.py +++ b/brainpy/_src/math/event/_csr_matmat.py @@ -20,7 +20,7 @@ register_general_batching, XLACustomOp) from 
brainpy._src.math.sparse._utils import csr_to_coo -from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmv +from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmm from brainpy.errors import GPUOperatorNotFound ti = import_taichi() @@ -391,11 +391,11 @@ def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_event_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_event_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) def _event_csr_matmat_transpose( From a2124974efecb380a773e36143bf55035c3fa699 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 4 Feb 2024 14:27:56 +0800 Subject: [PATCH 07/23] Format codes --- brainpy/_src/math/event/_csr_matmat.py | 204 ++++++++++-------- .../_src/math/event/tests/test_event_csrmm.py | 28 ++- brainpy/_src/math/sparse/_csr_mm.py | 68 +++--- brainpy/_src/math/sparse/tests/test_csrmm.py | 34 +-- 4 files changed, 185 insertions(+), 149 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/_csr_matmat.py index fbd2d2524..c3e37f67d 100644 --- a/brainpy/_src/math/event/_csr_matmat.py +++ b/brainpy/_src/math/event/_csr_matmat.py @@ -73,8 +73,8 @@ def raw_event_csrmm_taichi( if np.ndim(data) == 1: if data.shape[0] not in [1, indices.shape[0]]: raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + indices = as_jax(indices) indptr = as_jax(indptr) matrix = as_jax(matrix) @@ -91,7 +91,7 @@ def raw_event_csrmm_taichi( result_shape = (out_shape, matrix.shape[1]) # if the shape of indices is (0,), then we return a zero matrix if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype),] + return [jnp.zeros(result_shape, dtype=data.dtype), ] assert matrix.shape[0] == (shape[0] if transpose else shape[1]) if transpose: @@ -129,10 +129,10 @@ def raw_event_csrmm_taichi( @ti.kernel def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): r = 0. 
@@ -144,12 +144,14 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), val = values[j] r += val out[row_k, col_i] = r + + @ti.kernel def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): r = 0. @@ -162,40 +164,43 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), r += val out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k] != 0.: r += values[row_j] out[row_i, col_k] = r + @ti.kernel def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. 
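      # Boolean events gate the accumulation: values[row_j] is added only where
      # the event entry matrix[col_indices[row_j], col_k] is True.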
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k]: r += values[row_j] out[row_i, col_k] = r + @ti.kernel def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -208,12 +213,13 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -226,44 +232,47 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k] != 0.: r += matrix[col_indices[row_j], col_k] out[row_i, col_k] = r * value + @ti.kernel def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k]: r += matrix[col_indices[row_j], col_k] out[row_i, col_k] = r * value + # GPU kernels @ti.kernel def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): r = 0. 
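      # Homogeneous weight: sum the event-gated entries first, then apply the
      # shared `value` once per output element (the trailing `r * value`).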
@@ -276,12 +285,13 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), r += val out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): r = 0. @@ -294,14 +304,15 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), r += val out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k] != 0.: @@ -311,24 +322,25 @@ def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), @ti.kernel def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. 
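      # Same accumulation as the CPU kernel; on GPU, Taichi runs the outermost
      # loop in parallel, so each thread handles one output row.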
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k]: r += values[row_j] out[row_i, col_k] = r + @ti.kernel def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -341,12 +353,13 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -359,30 +372,32 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k] != 0.: r += matrix[col_indices[row_j], col_k] out[row_i, col_k] = r * value + @ti.kernel def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. 
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): if matrix[col_indices[row_j], col_k]: @@ -411,10 +426,11 @@ def _event_csr_matmat_transpose( if type(ct[0]) is ad.Zero: ct_data = ad.Zero(data) else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + if data.aval.shape[0] == 1: # scalar + ct_data = \ + raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] ct_data = jnp.sum(ct[0] * ct_data) - else: # heter + else: # heter matrix = jnp.asarray(matrix) row, col = csr_to_coo(indices, indptr) ct_data = (ct[0][row] * matrix[col]).sum(1) @@ -430,32 +446,32 @@ def _define_op(cpu_kernel, gpu_kernel): # transpose heter _event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter_cpu, - gpu_kernel=_event_csr_matmat_transpose_heter_gpu) + gpu_kernel=_event_csr_matmat_transpose_heter_gpu) # no transpose heter _event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter_cpu, - gpu_kernel=_event_csr_matmat_heter_gpu) + gpu_kernel=_event_csr_matmat_heter_gpu) # transpose homo _event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo_cpu, - gpu_kernel=_event_csr_matmat_transpose_homo_gpu) + gpu_kernel=_event_csr_matmat_transpose_homo_gpu) # no transpose homo _event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo_cpu, - gpu_kernel=_event_csr_matmat_homo_gpu) + gpu_kernel=_event_csr_matmat_homo_gpu) # bool transpose heter _event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter_cpu, - gpu_kernel=_event_csr_matmat_transpose_bool_heter_gpu) + gpu_kernel=_event_csr_matmat_transpose_bool_heter_gpu) # bool no transpose heter _event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter_cpu, - gpu_kernel=_event_csr_matmat_bool_heter_gpu) + gpu_kernel=_event_csr_matmat_bool_heter_gpu) # bool transpose homo _event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo_cpu, - gpu_kernel=_event_csr_matmat_transpose_bool_homo_gpu) + gpu_kernel=_event_csr_matmat_transpose_bool_homo_gpu) # bool no transpose homo _event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo_cpu, - gpu_kernel=_event_csr_matmat_bool_homo_gpu) \ No newline at end of file + gpu_kernel=_event_csr_matmat_bool_homo_gpu) diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py index e555a9214..52e0378c6 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmm.py +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -54,7 +54,8 @@ def test_homo(self, transpose, shape, homo_data): shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) c = bm.allclose(r1, r2, equal_nan=True) if not c: print(r1 - r2) @@ -106,7 +107,6 @@ def test_homo_vmap(self, transpose, shape, homo_data): bm.clear_buffer_memory() - @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], @@ -137,10 +137,11 @@ def 
test_homo_grad(self, transpose, shape, homo_data): dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() if transpose else ((dense * a) @ matrix).sum()), - argnums=0) + argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.event.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -148,10 +149,11 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() if transpose else ((dense * homo_data) @ m).sum()), - argnums=0) + argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + homo_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -179,7 +181,7 @@ def test_heter(self, transpose, shape): heter_data = bm.as_jax(rng.random(indices.shape)) r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) r2 = bm.event.csrmm(heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) @@ -254,16 +256,20 @@ def test_heter_grad(self, transpose, shape): # grad data r1 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) r2 = jax.grad(sum_op(bm.event.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) # grad events matrix r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 075270790..eda72977a 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -86,7 +86,7 @@ def raw_csrmm_taichi( if np.ndim(data) == 1: if data.shape[0] not in [1, indices.shape[0]]: raise ValueError('The size of data should be 1 or be consistent with indices.' 
- f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') assert indptr.shape[0] == shape[0] + 1 if not jnp.issubdtype(indices.dtype, jnp.integer): raise ValueError('indices should be a 1D vector with integer type.') @@ -97,10 +97,10 @@ def raw_csrmm_taichi( result_shape = (out_shape, matrix.shape[1]) # if the shape of indices is (0,), then we return a zero matrix if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype),] + return [jnp.zeros(result_shape, dtype=data.dtype), ] assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - if transpose: + if transpose: if data.shape[0] == 1: prim = _csr_matmat_transpose_homo_p else: @@ -145,19 +145,20 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for j in range(row_ptr[row_i], row_ptr[row_i + 1]): r += values[j] * matrix[col_indices[j], col_k] out[row_i, col_k] = r + @ti.kernel def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -169,20 +170,22 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): r += matrix[col_indices[row_j], col_k] out[row_i, col_k] = r * value + # GPU kernels @ti.kernel @@ -209,19 +212,20 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. 
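      # Standard CSR-row x dense-column product. The loop bounds are now read
      # from `out` (equal to the CSR row count), presumably to keep the sparse
      # and event kernels uniform.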
for j in range(row_ptr[row_i], row_ptr[row_i + 1]): r += values[j] * matrix[col_indices[j], col_k] out[row_i, col_k] = r + @ti.kernel def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i in range(out.shape[1]): for row_k in range(out.shape[0]): @@ -233,20 +237,22 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), break out[row_k, col_i] = r + @ti.kernel def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for col_k in range(matrix.shape[1]): + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): r += matrix[col_indices[row_j], col_k] out[row_i, col_k] = r * value + def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) @@ -268,10 +274,10 @@ def _csr_matmat_transpose( if type(ct[0]) is ad.Zero: ct_data = ad.Zero(data) else: - if data.aval.shape[0] == 1: # scalar + if data.aval.shape[0] == 1: # scalar ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] ct_data = jnp.sum(ct[0] * ct_data) - else: # heter + else: # heter matrix = jnp.asarray(matrix) row, col = csr_to_coo(indices, indptr) ct_data = (ct[0][row] * matrix[col]).sum(1) @@ -295,8 +301,8 @@ def _define_op(cpu_kernel, gpu_kernel): # transpose homo _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) + gpu_kernel=_csr_matmat_transpose_homo_gpu) # no transpose homo _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, - gpu_kernel=_csr_matmat_homo_gpu) \ No newline at end of file + gpu_kernel=_csr_matmat_homo_gpu) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index 9cd7f0133..bb006b1db 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -54,7 +54,8 @@ def test_homo(self, transpose, shape, homo_data): shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) c = bm.allclose(r1, r2, equal_nan=True) if not c: print(r1 - r2) @@ -83,21 +84,22 @@ def test_homo_vmap(self, transpose, shape, homo_data): matrix = bm.as_jax(matrix) heter_data = bm.ones((10, indices.shape[0])) * homo_data - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=(shape[1], shape[0]) if transpose else 
(shape[0], shape[1])))(heter_data) + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=(shape[1], shape[0]) if transpose else ( + shape[0], shape[1])))(heter_data) # vmap 'data' f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix)) f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) vmap_data = bm.as_jax([homo_data] * 10) - + r1 = f1(dense) r2 = f2(vmap_data) self.assertTrue(bm.allclose(r1, r2)) bm.clear_buffer_memory() - @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], @@ -128,10 +130,11 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() if transpose else ((dense * a) @ matrix).sum()), - argnums=0) + argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.sparse.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -139,10 +142,11 @@ def test_homo_grad(self, transpose, shape, homo_data): dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() if transpose else ((dense * homo_data) @ m).sum()), - argnums=0) + argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + homo_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r3, r4)) @@ -173,7 +177,8 @@ def test_heter(self, transpose, shape): shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) c = bm.allclose(r1, r2, equal_nan=True) if not c: print(r1 - r2) @@ -201,7 +206,9 @@ def test_heter_vmap(self, transpose, shape): matrix = bm.as_jax(matrix) heter_data = bm.as_jax(rng.random((10, indices.shape[0]))) - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data) + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=(shape[1], shape[0]) if transpose else ( + shape[0], shape[1])))(heter_data) f1 = lambda a: (a.T @ matrix) if transpose else (a @ matrix) f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, @@ -211,7 +218,6 @@ def test_heter_vmap(self, transpose, shape): self.assertTrue(bm.allclose(r1, r2, equal_nan=True)) - @parameterized.product( transpose=[True, False], shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], @@ -244,7 +250,8 @@ def test_heter_grad(self, transpose, shape): argnums=0) r1 = dense_f1(dense) r2 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose 
+ heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose ) rows, cols = bm.sparse.csr_to_coo(indices, indptr) if transpose: @@ -261,7 +268,8 @@ def test_heter_grad(self, transpose, shape): (dense @ m).sum())) r3 = dense_f2(matrix) r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose ) self.assertTrue(bm.allclose(r3, r4)) From 343d8a0fb045a8f86e4a9818427be050961f9e75 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 10:22:49 +0800 Subject: [PATCH 08/23] Format codes and import --- brainpy/_src/math/event/_csr_matmat.py | 532 +++++++++--------- .../_src/math/event/tests/test_event_csrmm.py | 512 ++++++++--------- brainpy/_src/math/sparse/_csr_mm.py | 354 ++++++------ brainpy/_src/math/sparse/tests/test_csrmm.py | 510 ++++++++--------- 4 files changed, 946 insertions(+), 962 deletions(-) diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/_csr_matmat.py index c3e37f67d..024f1692f 100644 --- a/brainpy/_src/math/event/_csr_matmat.py +++ b/brainpy/_src/math/event/_csr_matmat.py @@ -1,128 +1,120 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple import jax -import numba import numpy as np -from jax import core, dtypes from jax import numpy as jnp -from jax.interpreters import ad, mlir, xla -from jax.lib import xla_client -from jaxlib import gpu_sparse +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.op_register import (XLACustomOp) from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmm -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.sparse._utils import csr_to_coo ti = import_taichi() __all__ = [ - 'csrmm', + 'csrmm', ] def csrmm( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, ): - """Product of CSR sparse matrix and a dense event matrix. + """Product of CSR sparse matrix and a dense event matrix. - Args: - data : array of shape ``(nse,)``, float. - indices : array of shape ``(nse,)`` - indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` - B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and - dtype ``data.dtype`` - shape : length-2 tuple representing the matrix shape - transpose : boolean specifying whether to transpose the sparse matrix - before computing. + Args: + data : array of shape ``(nse,)``, float. 
+ indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + dtype ``data.dtype`` + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix + before computing. - Returns: - C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. - """ - return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + Returns: + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + representing the matrix-matrix product product. + """ + return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] def raw_event_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, ): - assert len(shape) == 2 - - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - - assert data.ndim == indices.ndim == indptr.ndim == 1 - assert matrix.ndim == 2 - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - if transpose: - if matrix.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csr_matmat_transpose_bool_homo_p - else: - prim = _event_csr_matmat_transpose_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csr_matmat_transpose_homo_p - else: - prim = _event_csr_matmat_transpose_heter_p + assert len(shape) == 2 + + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' 
+ f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert matrix.ndim == 2 + assert indptr.shape[0] == shape[0] + 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype), ] + + assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if transpose: + if matrix.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csr_matmat_transpose_bool_homo_p + else: + prim = _event_csr_matmat_transpose_bool_heter_p else: - if matrix.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csr_matmat_bool_homo_p - else: - prim = _event_csr_matmat_bool_heter_p - else: - if data.shape[0] == 1: - prim = _event_csr_matmat_homo_p - else: - prim = _event_csr_matmat_heter_p - return prim(data, - indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], - transpose=transpose, - shape=shape) + if data.shape[0] == 1: + prim = _event_csr_matmat_transpose_homo_p + else: + prim = _event_csr_matmat_transpose_heter_p + else: + if matrix.dtype == jnp.bool_: + if data.shape[0] == 1: + prim = _event_csr_matmat_bool_homo_p + else: + prim = _event_csr_matmat_bool_heter_p + else: + if data.shape[0] == 1: + prim = _event_csr_matmat_homo_p + else: + prim = _event_csr_matmat_heter_p + return prim(data, + indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], + transpose=transpose, + shape=shape) # CPU kernels @@ -133,17 +125,17 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -152,17 +144,17 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. 
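+          # Linear scan of row row_j's stored columns for row_k; `val` stays 0.
+          # when the connection is absent, so only real synapses contribute.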
+ for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -171,13 +163,13 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] + out[row_i, col_k] = r @ti.kernel @@ -186,13 +178,13 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] + out[row_i, col_k] = r @ti.kernel @@ -201,17 +193,17 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -220,17 +212,17 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -239,14 +231,14 @@ def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. 
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -255,14 +247,14 @@ def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -273,17 +265,17 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -292,17 +284,17 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -311,13 +303,13 @@ def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. 
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] + out[row_i, col_k] = r @ti.kernel @@ -326,13 +318,13 @@ def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] + out[row_i, col_k] = r @ti.kernel @@ -341,17 +333,17 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -360,17 +352,17 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -379,14 +371,14 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. 
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -395,53 +387,53 @@ def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) def _event_csr_matmat_transpose( - ct, data, indices, indptr, matrix, *, outs, transpose, shape, + ct, data, indices, indptr, matrix, *, outs, transpose, shape, ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(matrix): - ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) - + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(matrix): + ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = \ - raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.sum(ct[0] * ct_data) - else: # heter - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - ct_data = (ct[0][row] * matrix[col]).sum(1) - return ct_data, indices, indptr, matrix + if data.aval.shape[0] == 1: # scalar + ct_data = \ + raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.sum(ct[0] * ct_data) + else: # heter + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + ct_data = (ct[0][row] * matrix[col]).sum(1) + return ct_data, indices, indptr, matrix def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) - 
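# Both the event and the plain CSR matmat primitives are linear in `values` and
# in the dense operand, which is why the JVP rules above just rerun the forward
# on the tangent and the transpose (VJP) rule reruns it with `transpose`
# flipped. A quick gradient self-check against that rule, using the plain
# sparse op from this patch (sketch; assumes a brainpy build from this branch):
import jax
import jax.numpy as jnp
import brainpy.math as bm

data = jnp.array([1., 2., 3.])
indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
B = jnp.ones((3, 4))

f = lambda d: bm.sparse.csrmm(d, indices, indptr, B, shape=(2, 3)).sum()
print(jax.grad(f)(data))  # per-edge gradient B[indices[j], :].sum() -> [4. 4. 4.]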
prim.def_transpose_rule(_event_csr_matmat_transpose) - return prim + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) + prim.def_transpose_rule(_event_csr_matmat_transpose) + return prim # transpose heter diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py index 52e0378c6..c570d1537 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmm.py +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -14,263 +14,263 @@ def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() - return func + return func class Test_csrmm(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmm, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones(indices.shape) * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] 
- ) - def test_homo_vmap(self, transpose, shape, homo_data): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # vmap 'data' - f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - heter_data = bm.ones((10, indices.shape[0])) * homo_data - r1 = f1(heter_data) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - # vmap 'events' - heter_data = bm.ones(indices.shape) * homo_data - f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmm, homo_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) - r3 = f3(matrix) - r4 = f4(matrix) - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() - if transpose else - ((dense * a) @ matrix).sum()), - argnums=0) - r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.event.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() - if transpose else - ((dense * homo_data) @ m).sum()), - argnums=0) - r3 = dense_f2(matrix.astype(float)) - r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose: {transpose} shape = {shape}') 
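# The vmap tests above always batch exactly one argument and close over the
# rest with functools.partial. The core pattern, reduced to toy sizes (sketch;
# assumes a brainpy build from this branch):
import jax
import jax.numpy as jnp
from functools import partial
import brainpy.math as bm

indices = jnp.array([0, 2, 1])
indptr = jnp.array([0, 2, 3])
matrix = jnp.ones((3, 4))

f = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix,
            shape=(2, 3), transpose=False)
out = jax.vmap(f)(jnp.ones((10, 3)))  # ten per-edge weight vectors -> (10, 2, 4)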
- conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random(indices.shape)) - - r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r2 = bm.event.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_vmap(self, transpose, shape): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # vmap 'data' - f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) - r1 = f1(vmap_data) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - # vmap 'events' - heter_data = bm.ones(indices.shape) - f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - f4 = jax.vmap(partial(bm.event.csrmm, heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) - r3 = f3(matrix) - r4 = f4(matrix) - self.assertTrue(bm.allclose(r3, r4)) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_grad(self, transpose, shape): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - r1 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - r2 = jax.grad(sum_op(bm.event.csrmm))( - heter_data, indices, indptr, matrix, 
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmm, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape) * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
+ ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # vmap 'data' + f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + heter_data = bm.ones((10, indices.shape[0])) * homo_data + r1 = f1(heter_data) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + # vmap 'events' + heter_data = bm.ones(indices.shape) * homo_data + f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f4 = jax.vmap(partial(bm.event.csrmm, homo_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) + r3 = f3(matrix) + r4 = f4(matrix) + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo_grad(self, transpose, shape, homo_data): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() + if transpose else + ((dense * a) @ matrix).sum()), + argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(bm.event.csrmm))( + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() + if transpose else + ((dense * homo_data) @ m).sum()), + argnums=0) + r3 = dense_f2(matrix.astype(float)) + r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( + homo_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter(self, transpose, shape): + print(f'test_homo: transpose: {transpose} shape = {shape}') 
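# For a shared scalar weight a, the dense reference in test_homo_grad is
# (dense * a) @ m, whose summed gradient with respect to a is simply
# (dense @ m).sum(); the event/sparse gradients must agree with it. Worked
# scalar check (sketch):
import jax
import jax.numpy as jnp

dense = jnp.array([[1., 0.],
                   [1., 1.]])   # 0/1 connectivity
m = jnp.array([[2., 3.],
               [4., 5.]])
g = jax.grad(lambda a: ((dense * a) @ m).sum())(1.0)
print(g)  # (dense @ m).sum() = (2 + 3) + (2 + 3 + 4 + 5) = 19.0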
+ conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.as_jax(rng.random(indices.shape)) + + r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r2 = bm.event.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_vmap(self, transpose, shape): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # vmap 'data' + f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax(rng.random((10, indices.shape[0]))) + r1 = f1(vmap_data) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + # vmap 'events' + heter_data = bm.ones(indices.shape) + f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f4 = jax.vmap(partial(bm.event.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) + r3 = f3(matrix) + r4 = f4(matrix) + self.assertTrue(bm.allclose(r3, r4)) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_grad(self, transpose, shape): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + r1 = jax.grad(sum_op(bm.sparse.csrmm))( + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) + r2 = jax.grad(sum_op(bm.event.csrmm))( + heter_data, indices, indptr, matrix, 
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index eda72977a..c48e6a104 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -1,122 +1,114 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Union, Tuple import jax -import numba import numpy as np -from jax import core, dtypes from jax import numpy as jnp -from jax.interpreters import ad, mlir, xla -from jax.lib import xla_client -from jaxlib import gpu_sparse +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) +from brainpy._src.math.op_register import (XLACustomOp) from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound ti = import_taichi() __all__ = [ - 'csrmm', + 'csrmm', ] def csrmm( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, ): - """Product of CSR sparse matrix and a dense matrix. + """Product of CSR sparse matrix and a dense matrix. - Args: - data : array of shape ``(nse,)``. - indices : array of shape ``(nse,)`` - indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` - B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and - dtype ``data.dtype`` - shape : length-2 tuple representing the matrix shape - transpose : boolean specifying whether to transpose the sparse matrix - before computing. + Args: + data : array of shape ``(nse,)``. + indices : array of shape ``(nse,)`` + indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype`` + B : array of shape ``(shape[0] if transpose else shape[1], cols)`` and + dtype ``data.dtype`` + shape : length-2 tuple representing the matrix shape + transpose : boolean specifying whether to transpose the sparse matrix + before computing. - Returns: - C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. - """ - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + Returns: + C : array of shape ``(shape[1] if transpose else shape[0], cols)`` + representing the matrix-matrix product product. 
+ """ + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] def raw_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + matrix: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, ): - assert len(shape) == 2 - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - - data = jnp.atleast_1d(data) - - if matrix.dtype == jnp.bool_: - matrix = as_jax(matrix, dtype=data.dtype) - - if data.dtype != matrix.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {matrix.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == 1 - assert matrix.ndim == 2 - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - if transpose: - if data.shape[0] == 1: - prim = _csr_matmat_transpose_homo_p - else: - prim = _csr_matmat_transpose_heter_p + assert len(shape) == 2 + + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + + data = jnp.atleast_1d(data) + + if matrix.dtype == jnp.bool_: + matrix = as_jax(matrix, dtype=data.dtype) + + if data.dtype != matrix.dtype: + raise TypeError('The types of data and vector should be the same. ' + f'But we got {data.dtype} != {matrix.dtype}.') + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert matrix.ndim == 2 + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' 
+ f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + assert indptr.shape[0] == shape[0] + 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype), ] + + assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if transpose: + if data.shape[0] == 1: + prim = _csr_matmat_transpose_homo_p else: - if data.shape[0] == 1: - prim = _csr_matmat_homo_p - else: - prim = _csr_matmat_heter_p - return prim(data, - indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], - transpose=transpose, - shape=shape) + prim = _csr_matmat_transpose_heter_p + else: + if data.shape[0] == 1: + prim = _csr_matmat_homo_p + else: + prim = _csr_matmat_heter_p + return prim(data, + indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], + transpose=transpose, + shape=shape) # CPU kernels @@ -127,16 +119,16 @@ def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -145,12 +137,12 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -159,16 +151,16 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. 
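# The transposed kernels compute A.T @ B without materializing A.T: for each
# output row they scan the CSR rows and pick up entries whose column index
# matches. Dense reference for the transpose semantics (sketch, toy sizes):
import numpy as np

values = np.array([1., 2., 3.])
col_indices = np.array([0, 2, 1])
row_ptr = np.array([0, 2, 3])
A = np.zeros((2, 3))
for i in range(2):
    for j in range(row_ptr[i], row_ptr[i + 1]):
        A[i, col_indices[j]] = values[j]
B = np.ones((2, 4))  # transpose=True expects matrix.shape[0] == shape[0]
out = A.T @ B        # shape (3, 4): what the transpose kernels must produce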
+ for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -177,13 +169,13 @@ def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -194,16 +186,16 @@ def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -212,12 +204,12 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -226,16 +218,16 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + value = values[0] + for col_i in range(out.shape[1]): + for row_k in range(out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -244,51 +236,51 @@ def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. 
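# Because every stored entry equals the same scalar in the homogeneous case,
# A @ B reduces to value * (S @ B) with S the 0/1 sparsity structure; the homo
# kernels above exploit this by accumulating plain sums of B and multiplying by
# the weight once per output element. Sketch of the identity:
import numpy as np

value = 2.0
S = np.array([[1., 0., 1.],
              [0., 1., 0.]])   # structure only
B = np.arange(12.).reshape(3, 4)
assert np.allclose((value * S) @ B, value * (S @ B))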
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + value = values[0] + for row_i in range(out.shape[0]): + for col_k in range(out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) def _csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + return raw_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) def _csr_matmat_transpose( - ct, data, indices, indptr, matrix, *, outs, transpose, shape, + ct, data, indices, indptr, matrix, *, outs, transpose, shape, ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(matrix): - ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) - + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(matrix): + ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.sum(ct[0] * ct_data) - else: # heter - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - ct_data = (ct[0][row] * matrix[col]).sum(1) - return ct_data, indices, indptr, matrix + if data.aval.shape[0] == 1: # scalar + ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.sum(ct[0] * ct_data) + else: # heter + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + ct_data = (ct[0][row] * matrix[col]).sum(1) + return ct_data, indices, indptr, matrix def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_csr_matmat_jvp_values, None, None, _csr_matmat_jvp_matrix) - prim.def_transpose_rule(_csr_matmat_transpose) - return prim + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_csr_matmat_jvp_values, None, None, _csr_matmat_jvp_matrix) + prim.def_transpose_rule(_csr_matmat_transpose) + return prim # transpose heter diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index bb006b1db..8c6a9fa29 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -14,264 +14,264 @@ def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() + def func(*args, 
**kwargs): + r = op(*args, **kwargs) + return r.sum() - return func + return func class Test_csrmm(parameterized.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_csrmm, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo(self, transpose, shape, homo_data): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones(indices.shape) * homo_data - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] - ) - def test_homo_vmap(self, transpose, shape, homo_data): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - heter_data = bm.ones((10, indices.shape[0])) * homo_data - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=(shape[1], shape[0]) if transpose else ( - shape[0], shape[1])))(heter_data) - - # vmap 'data' - f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix)) - f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) - vmap_data = bm.as_jax([homo_data] * 10) - - r1 = f1(dense) - r2 = f2(vmap_data) - self.assertTrue(bm.allclose(r1, r2)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - homo_data=[-1., 1.] + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmm, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
+ ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape) * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones((10, indices.shape[0])) * homo_data + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=(shape[1], shape[0]) if transpose else ( + shape[0], shape[1])))(heter_data) + + # vmap 'data' + f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix)) + f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + + r1 = f1(dense) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
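# Oracle shared by every test in this file: expand the CSR operand to dense
# with bm.sparse.csr_to_dense and compare against a plain `@` product, e.g.
# (sketch, with illustrative dimensions m x n):
#   dense = bm.sparse.csr_to_dense(data, indices, indptr, shape=(m, n))
#   expected = dense.T @ B if transpose else dense @ B
#   assert bm.allclose(expected, bm.sparse.csrmm(data, indices, indptr, B,
#                                                shape=(m, n), transpose=transpose))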
+ ) + def test_homo_grad(self, transpose, shape, homo_data): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + heter_data = bm.as_jax(rng.random((indices.shape))) + # matrix + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() + if transpose else + ((dense * a) @ matrix).sum()), + argnums=0) + r1 = dense_f1(homo_data) + r2 = jax.grad(sum_op(bm.sparse.csrmm))( + homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad events matrix + dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() + if transpose else + ((dense * homo_data) @ m).sum()), + argnums=0) + r3 = dense_f2(matrix.astype(float)) + r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + homo_data, indices, indptr, matrix.astype(float), + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter(self, transpose, shape): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + heter_data = bm.as_jax(rng.random(indices.shape)) + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_vmap(self, transpose, shape): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + heter_data = bm.as_jax(rng.random((10, indices.shape[0]))) + dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, + shape=(shape[1], shape[0]) if transpose else ( + shape[0], shape[1])))(heter_data) + + f1 = lambda a: (a.T @ matrix) if 
transpose else (a @ matrix) + f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + r1 = jax.vmap(f1)(dense) + r2 = jax.vmap(f2)(heter_data) + + self.assertTrue(bm.allclose(r1, r2, equal_nan=True)) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + ) + def test_heter_grad(self, transpose, shape): + print(f'test_homo_grad: transpose: {transpose} shape = {shape}') + rng = bm.random.RandomState(seed=seed) + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + + heter_data = bm.as_jax(rng.random((indices.shape))) + dense = bm.sparse.csr_to_dense(heter_data, + indices, + indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + # matrix + matrix = rng.random((shape[1], shape[2])) + matrix = bm.as_jax(matrix) + + # grad data + dense_f1 = jax.grad(lambda a: ((a.T @ matrix).sum() + if transpose else + (a @ matrix).sum()), + argnums=0) + r1 = dense_f1(dense) + r2 = jax.grad(sum_op(bm.sparse.csrmm))( + heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + transpose=transpose ) - def test_homo_grad(self, transpose, shape, homo_data): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - heter_data = bm.as_jax(rng.random((indices.shape))) - # matrix - matrix = rng.random((shape[1], shape[2])) < 0.1 - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum() - if transpose else - ((dense * a) @ matrix).sum()), - argnums=0) - r1 = dense_f1(homo_data) - r2 = jax.grad(sum_op(bm.sparse.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose) - - self.assertTrue(bm.allclose(r1, r2)) - - # grad events matrix - dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum() - if transpose else - ((dense * homo_data) @ m).sum()), - argnums=0) - r3 = dense_f2(matrix.astype(float)) - r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + rows, cols = bm.sparse.csr_to_coo(indices, indptr) + if transpose: + r1 = r1[cols, rows] + else: + r1 = r1[rows, cols] + print(r1 - r2) + + self.assertTrue(bm.allclose(r1, r2)) + + # grad matrix + dense_f2 = jax.grad(lambda m: ((dense.T @ m).sum() + if transpose else + (dense @ m).sum())) + r3 = dense_f2(matrix) + r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( + heter_data, indices, indptr, matrix.astype(float), + shape=(shape[1], 
shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose ) - def test_heter(self, transpose, shape): - print(f'test_homo: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random(indices.shape)) - - dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - - r1 = (dense.T @ matrix) if transpose else (dense @ matrix) - r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - c = bm.allclose(r1, r2, equal_nan=True) - if not c: - print(r1 - r2) - self.assertTrue(c) - - bm.clear_buffer_memory() - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_vmap(self, transpose, shape): - print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - # matrix - rng = bm.random.RandomState(seed=seed) - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - heter_data = bm.as_jax(rng.random((10, indices.shape[0]))) - dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(a, indices, indptr, - shape=(shape[1], shape[0]) if transpose else ( - shape[0], shape[1])))(heter_data) - - f1 = lambda a: (a.T @ matrix) if transpose else (a @ matrix) - f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) - r1 = jax.vmap(f1)(dense) - r2 = jax.vmap(f2)(heter_data) - - self.assertTrue(bm.allclose(r1, r2, equal_nan=True)) - - @parameterized.product( - transpose=[True, False], - shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], - ) - def test_heter_grad(self, transpose, shape): - print(f'test_homo_grad: transpose: {transpose} shape = {shape}') - rng = bm.random.RandomState(seed=seed) - conn = bp.conn.FixedProb(0.3) - - # csr matrix - indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], - shape[1]).require( - 'pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - - heter_data = bm.as_jax(rng.random((indices.shape))) - dense = bm.sparse.csr_to_dense(heter_data, - indices, - indptr, - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) - # matrix - matrix = rng.random((shape[1], shape[2])) - matrix = bm.as_jax(matrix) - - # grad data - dense_f1 = jax.grad(lambda a: ((a.T @ matrix).sum() - if transpose else - (a @ matrix).sum()), - argnums=0) - r1 = dense_f1(dense) - r2 = jax.grad(sum_op(bm.sparse.csrmm))( - heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), - transpose=transpose - ) - rows, cols = bm.sparse.csr_to_coo(indices, indptr) - if transpose: - r1 = r1[cols, rows] - else: - r1 = r1[rows, cols] - print(r1 - r2) - - self.assertTrue(bm.allclose(r1, r2)) - - # 
grad matrix - dense_f2 = jax.grad(lambda m: ((dense.T @ m).sum() - if transpose else - (dense @ m).sum())) - r3 = dense_f2(matrix) - r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - heter_data, indices, indptr, matrix.astype(float), - shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose - ) - - self.assertTrue(bm.allclose(r3, r4)) - - bm.clear_buffer_memory() + + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() From 7fe78b2772c17972b345ec1dd9d6ecb4e448d64b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 16:37:44 +0800 Subject: [PATCH 09/23] [math] Replace csr matmat heter operators with CUSPARSE --- brainpy/_src/math/sparse/_csr_mm.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index c48e6a104..1062331b0 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -7,6 +7,7 @@ import numpy as np from jax import numpy as jnp from jax.interpreters import ad +from jax.experimental.sparse import csr from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax @@ -92,16 +93,13 @@ def raw_csrmm_taichi( return [jnp.zeros(result_shape, dtype=data.dtype), ] assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - if transpose: - if data.shape[0] == 1: + if data.shape[0] != 1: + return _csr_matmat_heter_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose) + else: + if transpose: prim = _csr_matmat_transpose_homo_p else: - prim = _csr_matmat_transpose_heter_p - else: - if data.shape[0] == 1: prim = _csr_matmat_homo_p - else: - prim = _csr_matmat_heter_p return prim(data, indices, indptr, @@ -283,13 +281,13 @@ def _define_op(cpu_kernel, gpu_kernel): return prim -# transpose heter -_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, - gpu_kernel=_csr_matmat_transpose_heter_gpu) - -# no transpose heter -_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, - gpu_kernel=_csr_matmat_heter_gpu) +# # transpose heter +# _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, +# gpu_kernel=_csr_matmat_transpose_heter_gpu) +# +# # no transpose heter +# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, +# gpu_kernel=_csr_matmat_heter_gpu) # transpose homo _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, @@ -298,3 +296,6 @@ def _define_op(cpu_kernel, gpu_kernel): # no transpose homo _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, gpu_kernel=_csr_matmat_homo_gpu) + +# heter CUSPARSE +_csr_matmat_heter_p = csr.csr_matmat_p \ No newline at end of file From b8ef61ac17f834573d04c908516a509f9f02289c Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 16:53:10 +0800 Subject: [PATCH 10/23] Fix csr matmat bugs --- brainpy/_src/math/sparse/_csr_mm.py | 26 +++++++++++++------- brainpy/_src/math/sparse/tests/test_csrmm.py | 2 ++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 1062331b0..12c4e703a 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -12,7 +12,7 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from 
brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp) +from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) from brainpy._src.math.sparse._utils import csr_to_coo ti = import_taichi() @@ -47,7 +47,19 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + data = jnp.atleast_1d(data) + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix + + if data.shape[0] != 1: + if indices.shape[0] == 0: + return jnp.zeros(result_shape, dtype=data.dtype) + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose) + else: + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype), ] + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] def raw_csrmm_taichi( @@ -65,8 +77,6 @@ def raw_csrmm_taichi( indptr = as_jax(indptr) matrix = as_jax(matrix) - data = jnp.atleast_1d(data) - if matrix.dtype == jnp.bool_: matrix = as_jax(matrix, dtype=data.dtype) @@ -75,7 +85,7 @@ def raw_csrmm_taichi( f'But we got {data.dtype} != {matrix.dtype}.') assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 - data = jnp.atleast_1d(data) + if np.ndim(data) == 1: if data.shape[0] not in [1, indices.shape[0]]: raise ValueError('The size of data should be 1 or be consistent with indices.' @@ -88,9 +98,6 @@ def raw_csrmm_taichi( out_shape = shape[1] if transpose else shape[0] result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] assert matrix.shape[0] == (shape[0] if transpose else shape[1]) if data.shape[0] != 1: @@ -298,4 +305,5 @@ def _define_op(cpu_kernel, gpu_kernel): gpu_kernel=_csr_matmat_homo_gpu) # heter CUSPARSE -_csr_matmat_heter_p = csr.csr_matmat_p \ No newline at end of file +_csr_matmat_heter_p = csr.csr_matmat_p +register_general_batching(_csr_matmat_heter_p) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index 8c6a9fa29..c05b45cdf 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -179,6 +179,8 @@ def test_heter(self, transpose, shape): r1 = (dense.T @ matrix) if transpose else (dense @ matrix) r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + print(r2) + print(r1.shape, '-', r2.shape) c = bm.allclose(r1, r2, equal_nan=True) if not c: print(r1 - r2) From c37ce92e24b484234617f87f4c8bf5b52b032062 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 20:33:30 +0800 Subject: [PATCH 11/23] Fix bug --- brainpy/_src/math/sparse/_csr_mm.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index 12c4e703a..c09c0ff76 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -58,7 +58,7 @@ def csrmm( return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose) else: if indices.shape[0] == 0: - return [jnp.zeros(result_shape, 
dtype=data.dtype), ] + return jnp.zeros(result_shape, dtype=data.dtype) return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] @@ -100,6 +100,8 @@ def raw_csrmm_taichi( result_shape = (out_shape, matrix.shape[1]) assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype), ] if data.shape[0] != 1: return _csr_matmat_heter_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose) else: From 9ea9800609001dafade773b1ecd96f8386e243e7 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 17 Feb 2024 21:10:47 +0800 Subject: [PATCH 12/23] Update operator selection strategy for csr matmat homo -> taichi, heter(CPU) -> taichi, heter(GPU) -> cusparse --- brainpy/_src/math/sparse/_csr_mm.py | 46 ++++++++++++++--------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/_csr_mm.py index c09c0ff76..b5e21446c 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/_csr_mm.py @@ -5,6 +5,7 @@ import jax import numpy as np +import brainpy.math as bm from jax import numpy as jnp from jax.interpreters import ad from jax.experimental.sparse import csr @@ -47,19 +48,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - data = jnp.atleast_1d(data) - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - - if data.shape[0] != 1: - if indices.shape[0] == 0: - return jnp.zeros(result_shape, dtype=data.dtype) - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose) - else: - if indices.shape[0] == 0: - return jnp.zeros(result_shape, dtype=data.dtype) - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] def raw_csrmm_taichi( @@ -76,6 +65,7 @@ def raw_csrmm_taichi( indices = as_jax(indices) indptr = as_jax(indptr) matrix = as_jax(matrix) + data = jnp.atleast_1d(data) if matrix.dtype == jnp.bool_: matrix = as_jax(matrix, dtype=data.dtype) @@ -83,7 +73,6 @@ def raw_csrmm_taichi( if data.dtype != matrix.dtype: raise TypeError('The types of data and vector should be the same. 
' f'But we got {data.dtype} != {matrix.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == 1 assert matrix.ndim == 2 if np.ndim(data) == 1: @@ -100,10 +89,19 @@ def raw_csrmm_taichi( result_shape = (out_shape, matrix.shape[1]) assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + if indices.shape[0] == 0: return [jnp.zeros(result_shape, dtype=data.dtype), ] + # homo -> taichi, + # heter(CPU) -> taichi, heter(GPU) -> cusparse if data.shape[0] != 1: - return _csr_matmat_heter_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose) + if bm.get_platform() == 'gpu': + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] + else: + if transpose: + prim = _csr_matmat_transpose_heter_p + else: + prim = _csr_matmat_heter_p else: if transpose: prim = _csr_matmat_transpose_homo_p @@ -290,13 +288,13 @@ def _define_op(cpu_kernel, gpu_kernel): return prim -# # transpose heter -# _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, -# gpu_kernel=_csr_matmat_transpose_heter_gpu) -# -# # no transpose heter -# _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, -# gpu_kernel=_csr_matmat_heter_gpu) +# transpose heter +_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, + gpu_kernel=_csr_matmat_transpose_heter_gpu) + +# no transpose heter +_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, + gpu_kernel=_csr_matmat_heter_gpu) # transpose homo _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, @@ -307,5 +305,5 @@ def _define_op(cpu_kernel, gpu_kernel): gpu_kernel=_csr_matmat_homo_gpu) # heter CUSPARSE -_csr_matmat_heter_p = csr.csr_matmat_p -register_general_batching(_csr_matmat_heter_p) +_csr_matmat_cusparse_p = csr.csr_matmat_p +register_general_batching(_csr_matmat_cusparse_p) From 8737b939a7ee959af443a4d6c56386edfff986e8 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 3 Mar 2024 16:53:25 +0800 Subject: [PATCH 13/23] [math] Update event csrmm and csrmm --- brainpy/_src/math/event/__init__.py | 1 + .../event/{_csr_matmat.py => csr_matmat.py} | 268 ++++++++-------- .../tests/event_csr_matmat_VS_csr_matmat.py | 285 ++++++++++++++++++ .../_src/math/event/tests/test_event_csrmm.py | 4 +- brainpy/_src/math/sparse/__init__.py | 1 + .../math/sparse/{_csr_mm.py => csr_mm.py} | 126 ++++---- .../csr_matmat_VS_cusparse_csr_matmat.py | 258 ++++++++++++++++ brainpy/_src/math/sparse/tests/test_csrmm.py | 4 +- brainpy/math/event.py | 1 + 9 files changed, 736 insertions(+), 212 deletions(-) rename brainpy/_src/math/event/{_csr_matmat.py => csr_matmat.py} (71%) create mode 100644 brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py rename brainpy/_src/math/sparse/{_csr_mm.py => csr_mm.py} (78%) create mode 100644 brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 9ebad3e94..91b479b62 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,2 +1,3 @@ from .csr_matvec import * +from .csr_matmat import * diff --git a/brainpy/_src/math/event/_csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py similarity index 71% rename from brainpy/_src/math/event/_csr_matmat.py rename to brainpy/_src/math/event/csr_matmat.py index 024f1692f..6936b495a 100644 --- a/brainpy/_src/math/event/_csr_matmat.py +++ 
b/brainpy/_src/math/event/csr_matmat.py @@ -12,8 +12,8 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (XLACustomOp) -from brainpy._src.math.sparse._csr_mm import raw_csrmm_taichi as normal_csrmm -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm +from brainpy._src.math.sparse.utils import csr_to_coo ti = import_taichi() @@ -125,17 +125,16 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -144,17 +143,16 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -163,13 +161,12 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -178,13 +175,12 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
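# Reference semantics for the heter kernels above (an illustrative NumPy
# sketch, not part of the patch): with float "events", every stored weight
# must be scaled by the gathered matrix entry, so the old `r += values[row_j]`
# computed a structure-weighted sum rather than A @ M. The event test
# `matrix[...] != 0.` is only a skip; the math is plain SpMM:
import numpy as np

def event_csrmm_ref(data, indices, indptr, M):
    # A @ M with A given as a CSR triple (data, indices, indptr).
    m = len(indptr) - 1
    out = np.zeros((m, M.shape[1]), dtype=np.result_type(data, M))
    for i in range(m):
        for j in range(indptr[i], indptr[i + 1]):
            out[i, :] += data[j] * M[indices[j], :]
    return out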
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -194,16 +190,15 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -213,16 +208,15 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -232,13 +226,12 @@ def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -248,13 +241,12 @@ def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -265,17 +257,16 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. 
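# Background on the loop rewrite in these hunks: Taichi parallelizes only
# the outermost loop of a kernel, so the old `for col_i: for row_k:` form
# ran the inner loop serially inside each outer iteration. `ti.ndrange`
# flattens both index sets into a single parallel loop. A minimal
# standalone example (assumes taichi is installed; names are illustrative):
import numpy as np
import taichi as ti

ti.init(arch=ti.cpu)

@ti.kernel
def fill(out: ti.types.ndarray(ndim=2)):
    for i, j in ti.ndrange(out.shape[0], out.shape[1]):  # one parallel loop
        out[i, j] = i * 10 + j

a = np.zeros((2, 3))
fill(a)  # a == [[0, 1, 2], [10, 11, 12]]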
- for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -284,17 +275,16 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val + out[row_k, col_i] = r @ti.kernel @@ -303,13 +293,12 @@ def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -318,13 +307,12 @@ def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -334,16 +322,15 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. 
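# Note on the `break` in the transpose homo kernels being rewritten here:
# the inner search stops at the first j with col_indices[j] == row_k, which
# is exact only if each CSR row stores a column index at most once (no
# duplicate connections). Under that assumption the kernel computes
# value * (S.T @ M) for the 0/1 structure matrix S; a dense cross-check
# (illustrative sketch, not patch code):
import numpy as np

def structure_dense(indices, indptr, shape):
    S = np.zeros(shape)
    for i in range(shape[0]):
        S[i, indices[indptr[i]:indptr[i + 1]]] = 1.0
    return S
# value * (structure_dense(indices, indptr, (m, n)).T @ M) should match out.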
+ for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -353,16 +340,15 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -372,13 +358,12 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value @ti.kernel @@ -388,13 +373,12 @@ def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
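# Why the homo kernels accumulate bare matrix entries and multiply once:
# with a homogeneous weight w, A @ M == w * (S @ M) where S is the 0/1
# sparsity pattern, so one multiply per output element replaces one per
# stored nonzero. A quick property check (illustrative only):
import numpy as np

rng = np.random.default_rng(0)
S = (rng.random((4, 5)) < 0.3).astype(float)  # random 0/1 pattern
M = rng.random((5, 3))
w = 2.5
assert np.allclose((w * S) @ M, w * (S @ M))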
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): diff --git a/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py new file mode 100644 index 000000000..872c69e14 --- /dev/null +++ b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py @@ -0,0 +1,285 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'heter', + 'homo', +] +events_type = ['bool', + 'float', + ] +transpose = [ + # True, + False +] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +def test_sparse_csrmm(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = rng.random(matrix2_shape) + matrix = bm.as_jax(matrix) + weight = 1. 
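# Note on the timing pattern used below: the first five untimed calls
# trigger JIT compilation, and jax.block_until_ready() forces JAX's
# asynchronous dispatch to finish before each clock read. The ten
# copy-pasted timing blocks could be folded into a helper along these
# lines (a hypothetical sketch, not part of the benchmark):
import time
import jax

def time_runs(fn, *args, warmup=5, reps=10, **kwargs):
    for _ in range(warmup):
        jax.block_until_ready(fn(*args, **kwargs))   # compile + warm caches
    times_ms = []
    for _ in range(reps):
        t0 = time.time()
        jax.block_until_ready(fn(*args, **kwargs))
        times_ms.append((time.time() - t0) * 1000.0)
    return times_ms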
+ + if events_type == 'float': + matrix = matrix.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time19 = time.time() + + result1 = result + + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time21 = time.time() + + result2 = result + + time22 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = 
jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time39 = time.time() + + csrmm_time1 = (time1 - time0) * 1000 + csrmm_time2 = (time3 - time2) * 1000 + csrmm_time3 = (time5 - time4) * 1000 + csrmm_time4 = (time7 - time6) * 1000 + csrmm_time5 = (time9 - time8) * 1000 + csrmm_time6 = (time11 - time10) * 1000 + csrmm_time7 = (time13 - time12) * 1000 + csrmm_time8 = (time15 - time14) * 1000 + csrmm_time9 = (time17 - time16) * 1000 + csrmm_time10 = (time19 - time18) * 1000 + event_csrmm_time1 = (time21 - time20) * 1000 + event_csrmm_time2 = (time23 - time22) * 1000 + event_csrmm_time3 = (time25 - time24) * 1000 + event_csrmm_time4 = (time27 - time26) * 1000 + event_csrmm_time5 = (time29 - time28) * 1000 + event_csrmm_time6 = (time31 - time30) * 1000 + event_csrmm_time7 = (time33 - time32) * 1000 + event_csrmm_time8 = (time35 - time34) * 1000 + event_csrmm_time9 = (time37 - time36) * 1000 + event_csrmm_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('csrmm_1: ', csrmm_time1, 'ms') + print('csrmm_3: ', csrmm_time3, 'ms') + print('csrmm_5: ', csrmm_time5, 'ms') + print('csrmm_7: ', csrmm_time7, 'ms') + print('csrmm_9: ', csrmm_time9, 'ms') + print('event_csrmm_1: ', event_csrmm_time1, 'ms') + print('event_csrmm_3: ', event_csrmm_time3, 'ms') + print('event_csrmm_5: ', event_csrmm_time5, 'ms') + print('event_csrmm_7: ', event_csrmm_time7, 'ms') + print('event_csrmm_9: ', event_csrmm_time9, 'ms') + + r = bm.allclose(result1, result2) + if not r: + print('result1: ', result1) + print('result2: ', result2) + + return csrmm_time1, csrmm_time2, csrmm_time3, csrmm_time4, csrmm_time5, \ + csrmm_time6, csrmm_time7, csrmm_time8, csrmm_time9, csrmm_time10, \ + event_csrmm_time1, event_csrmm_time2, event_csrmm_time3, event_csrmm_time4, event_csrmm_time5, \ + event_csrmm_time6, event_csrmm_time7, event_csrmm_time8, event_csrmm_time9, event_csrmm_time10 + + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame( + columns=['shape', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'csrmm time1(ms)', 'csrmm time2(ms)', 'csrmm time3(ms)', 'csrmm time4(ms)', + 'csrmm time5(ms)', + 'csrmm time6(ms)', 'csrmm time7(ms)', 'csrmm time8(ms)', 'csrmm time9(ms)', + 'csrmm time10(ms)', + 'event_csrmm time1(ms)', 'event_csrmm time2(ms)', 'event_csrmm time3(ms)', 'event_csrmm 
time4(ms)', + 'event_csrmm time5(ms)', + 'event_csrmm time6(ms)', 'event_csrmm time7(ms)', 'event_csrmm time8(ms)', 'event_csrmm time9(ms)', + 'event_csrmm time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py index c570d1537..12a35ef34 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmm.py +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -140,7 +140,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.event.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -152,7 +152,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), + bm.asarray([homo_data]), indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) 
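# Background for the bm.asarray([homo_data]) change above: the homo
# primitives carry the shared weight as a shape-(1,) array, and jax.grad
# returns a cotangent with the shape of its input, so wrapping the Python
# float keeps primal and gradient shapes consistent. The same idea with a
# stand-in function (hypothetical names):
import jax
import jax.numpy as jnp

def f(w, x):                 # w: shape-(1,) shared weight
    return (w[0] * x).sum()

g = jax.grad(f)(jnp.ones(1), jnp.arange(3.0))
# g.shape == (1,), matching the wrapped homo weight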
self.assertTrue(bm.allclose(r3, r4)) diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 14256cbce..13c9e1e28 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,6 +1,7 @@ # from ._coo_mv import * # from ._bsr_mv import * from .csr_mv import * +from .csr_mm import * from .utils import * from .bsr_mm import * from .jax_prim import * diff --git a/brainpy/_src/math/sparse/_csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py similarity index 78% rename from brainpy/_src/math/sparse/_csr_mm.py rename to brainpy/_src/math/sparse/csr_mm.py index b5e21446c..dba93a797 100644 --- a/brainpy/_src/math/sparse/_csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -8,13 +8,14 @@ import brainpy.math as bm from jax import numpy as jnp from jax.interpreters import ad +from jax.core import Tracer from jax.experimental.sparse import csr from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.math.sparse.utils import csr_to_coo ti = import_taichi() @@ -32,7 +33,8 @@ def csrmm( shape: Tuple[int, int], transpose: bool = False, ): - """Product of CSR sparse matrix and a dense matrix. + """ + Product of CSR sparse matrix and a dense matrix. Args: data : array of shape ``(nse,)``. @@ -46,7 +48,7 @@ def csrmm( Returns: C : array of shape ``(shape[1] if transpose else shape[0], cols)`` - representing the matrix-matrix product product. + representing the matrix-matrix product. """ return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] @@ -99,7 +101,7 @@ def raw_csrmm_taichi( return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: if transpose: - prim = _csr_matmat_transpose_heter_p + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: prim = _csr_matmat_heter_p else: @@ -124,16 +126,15 @@ def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -142,12 +143,11 @@ def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
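# The heter kernel above is textbook CSR SpMM: out[i, k] is the dot of CSR
# row i with column k of the dense matrix. A SciPy cross-check on a tiny
# triple (illustrative values, not from the tests):
import numpy as np
import scipy.sparse as sp

data = np.array([1.0, 2.0, 3.0])
indices = np.array([0, 2, 1])
indptr = np.array([0, 2, 3])                 # 2x3 CSR matrix
M = np.arange(6.0).reshape(3, 2)
ref = sp.csr_matrix((data, indices, indptr), shape=(2, 3)) @ M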
+ for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -157,15 +157,14 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -175,12 +174,11 @@ def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value # GPU kernels @@ -191,16 +189,15 @@ def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. + for row_j in range(matrix.shape[0]): + val = 0. + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + val = values[j] + r += val * matrix[row_j, col_i] + out[row_k, col_i] = r @ti.kernel @@ -209,12 +206,11 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * matrix[col_indices[j], col_k] + out[row_i, col_k] = r @ti.kernel @@ -224,15 +220,14 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i in range(out.shape[1]): - for row_k in range(out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] - break - out[row_k, col_i] = r + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + r = 0. 
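# Why the transpose kernels scan every row of the dense matrix: a CSR
# triple read column-wise is exactly the CSC form of the transpose, so
# A.T @ M needs either a search per output element (as here) or a scatter,
# which is what cuSPARSE provides on the GPU path. SciPy identity check
# (illustrative):
import numpy as np
import scipy.sparse as sp

data = np.array([1.0, 2.0, 3.0])
indices = np.array([0, 2, 1])
indptr = np.array([0, 2, 3])
A = sp.csr_matrix((data, indices, indptr), shape=(2, 3))
At = sp.csc_matrix((data, indices, indptr), shape=(3, 2))
assert (A.T != At).nnz == 0                  # same matrix, re-read as CSC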
+ for row_j in range(matrix.shape[0]): + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + r += value * matrix[row_j, col_i] + break + out[row_k, col_i] = r @ti.kernel @@ -242,12 +237,11 @@ def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): value = values[0] - for row_i in range(out.shape[0]): - for col_k in range(out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py new file mode 100644 index 000000000..79c8bef0a --- /dev/null +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -0,0 +1,258 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'heter' + ] +events_type = ['float'] +transpose = [ + True, + # False + ] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None) + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='cusparse') + return r + +def test_sparse_csrmm(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = rng.random(matrix2_shape) + matrix = bm.as_jax(matrix) + weight = 1. 
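# The connectivity below comes from bp.conn.FixedProb(0.05).require('pre2post'),
# i.e. the (indices, indptr) CSR pattern of a ~5%-dense random matrix. For
# reproducing the benchmark outside BrainPy, an equivalent pattern can be
# generated with SciPy (an assumed stand-in; it does not use FixedProb's RNG):
import scipy.sparse as sp

A = sp.random(1000, 1000, density=0.05, format='csr', random_state=1234)
indices, indptr = A.indices, A.indptr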
+ + + + if events_type == 'float': + matrix = matrix.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time19 = time.time() + + result1 = result + + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time21 = time.time() + + result2 = result + + time22 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, 
transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + print(bm.allclose(result1, result2)) + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 
'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py index c05b45cdf..e4346c841 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmm.py +++ b/brainpy/_src/math/sparse/tests/test_csrmm.py @@ -133,7 +133,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r1 = dense_f1(homo_data) r2 = jax.grad(sum_op(bm.sparse.csrmm))( - homo_data, indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), + bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r1, r2)) @@ -145,7 +145,7 @@ def test_homo_grad(self, transpose, shape, homo_data): argnums=0) r3 = dense_f2(matrix.astype(float)) r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)( - homo_data, indices, indptr, matrix.astype(float), + bm.asarray([homo_data]), indices, indptr, matrix.astype(float), shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) self.assertTrue(bm.allclose(r3, 
r4)) diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 02e98b8f3..3b4b5ed1e 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,3 +1,4 @@ from brainpy._src.math.event import ( csrmv as csrmv, + csrmm as csrmm, ) From b9e558445f888f5af4df6ca0ea31bc4e02ba8e92 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sun, 3 Mar 2024 17:25:18 +0800 Subject: [PATCH 14/23] Update sparse.py --- brainpy/math/sparse.py | 1 - 1 file changed, 1 deletion(-) diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 10abbeb92..8a209901f 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -5,7 +5,6 @@ from brainpy._src.math.sparse import ( csrmv, csrmm, - coomv, seg_matmul, From 861b340e50fa0c9835db2e7e51bfad6e8cf89e0b Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 4 Mar 2024 16:26:19 +0800 Subject: [PATCH 15/23] Update event csr matvec --- brainpy/_src/math/event/csr_matmat.py | 24 ++++++++++++------------ brainpy/_src/math/event/csr_matvec.py | 10 ++-------- brainpy/_src/math/sparse/csr_mm.py | 8 ++++---- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 6936b495a..949ed46e3 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -133,7 +133,7 @@ def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: val = values[j] - r += val + r += val * matrix[row_j, col_i] out[row_k, col_i] = r @@ -151,7 +151,7 @@ def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: val = values[j] - r += val + r += val * matrix[row_j, col_i] out[row_k, col_i] = r @@ -196,9 +196,9 @@ def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), if matrix[row_j, col_i] != 0.: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel @@ -214,9 +214,9 @@ def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), if matrix[row_j, col_i]: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel @@ -264,7 +264,7 @@ def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), val = 0. for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - val = values[j] + val = values[j] * matrix[row_j, col_i] r += val out[row_k, col_i] = r @@ -282,7 +282,7 @@ def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), val = 0. 
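# The homo-kernel hunks above hoist the shared weight out of the inner
# loop: sum_j (w * x_j) == w * sum_j x_j, so the kernel multiplies once
# per output element instead of once per stored nonzero. The regrouping is
# exact algebra; floating point may differ in the last ulps, which is why
# the tests compare with allclose. Quick check (illustrative):
import numpy as np

x = np.random.default_rng(1).random(1000)
w = 3.7
assert np.allclose((w * x).sum(), w * x.sum())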
for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - val = values[j] + val = values[j] * matrix[row_j, col_i] r += val out[row_k, col_i] = r @@ -328,9 +328,9 @@ def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), if matrix[row_j, col_i] != 0.: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel @@ -346,9 +346,9 @@ def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), if matrix[row_j, col_i]: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index 9890838e7..d44783450 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -131,10 +131,7 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_transpose_bool_heter_p else: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_homo_p - else: - prim = _event_csrmv_transpose_heter_p + return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) else: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -142,10 +139,7 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_bool_heter_p else: - if data.shape[0] == 1: - prim = _event_csrmv_homo_p - else: - prim = _event_csrmv_heter_p + return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) # computing return prim(data, diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index dba93a797..6fa1aedc4 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -162,9 +162,9 @@ def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), for row_j in range(matrix.shape[0]): for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel @@ -225,9 +225,9 @@ def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), for row_j in range(matrix.shape[0]): for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += value * matrix[row_j, col_i] + r += matrix[row_j, col_i] break - out[row_k, col_i] = r + out[row_k, col_i] = r * value @ti.kernel From fff64dbcb73a95a59d0322f9e55a4e0fa040f492 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 7 Mar 2024 15:16:27 +0800 Subject: [PATCH 16/23] Update new transpose taichi kernels of csrmm and event csrmm --- brainpy/_src/math/event/csr_matmat.py | 270 ++++++-------------------- brainpy/_src/math/sparse/csr_mm.py | 146 ++++---------- 2 files changed, 93 insertions(+), 323 deletions(-) diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 949ed46e3..a8f55afbb 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -93,10 +93,7 @@ def raw_event_csrmm_taichi( else: prim = _event_csr_matmat_transpose_bool_heter_p else: - if data.shape[0] == 1: - prim = _event_csr_matmat_transpose_homo_p - else: - prim = _event_csr_matmat_transpose_heter_p + return normal_csrmm(data, indices, indptr, matrix, shape=shape, 
transpose=transpose) else: if matrix.dtype == jnp.bool_: if data.shape[0] == 1: @@ -104,10 +101,7 @@ def raw_event_csrmm_taichi( else: prim = _event_csr_matmat_bool_heter_p else: - if data.shape[0] == 1: - prim = _event_csr_matmat_homo_p - else: - prim = _event_csr_matmat_heter_p + return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) return prim(data, indices, indptr, @@ -117,50 +111,42 @@ def raw_event_csrmm_taichi( shape=shape) -# CPU kernels +# taichi kernels @ti.kernel -def _event_csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): if matrix[row_j, col_i] != 0.: - val = 0. for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + out[row_k, col_i] += values[j] * matrix[row_j, col_i] @ti.kernel -def _event_csr_matmat_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): if matrix[row_j, col_i]: - val = 0. for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + out[row_k, col_i] += values[j] * matrix[row_j, col_i] @ti.kernel -def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): @@ -170,11 +156,11 @@ def _event_csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), @ti.kernel -def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. 
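# Each (row_i, col_k) output element is owned by exactly one thread of the
# ndrange above, so the scalar accumulator r is race-free and is written back
# to out only once.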
for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): @@ -184,179 +170,41 @@ def _event_csr_matmat_bool_heter_cpu(values: ti.types.ndarray(ndim=1), @ti.kernel -def _event_csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value - - -@ti.kernel -def _event_csr_matmat_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value - - -@ti.kernel -def _event_csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -@ti.kernel -def _event_csr_matmat_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -# GPU kernels - -@ti.kernel -def _event_csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] * matrix[row_j, col_i] - r += val - out[row_k, col_i] = r - - -@ti.kernel -def _event_csr_matmat_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - val = 0. 
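# (All the *_cpu/*_gpu duplicates deleted in this patch are replaced by a
# single Taichi kernel per case, registered below for both backends.)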
- for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] * matrix[row_j, col_i] - r += val - out[row_k, col_i] = r - - -@ti.kernel -def _event_csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_bool_heter_gpu(values: ti.types.ndarray(ndim=1), +def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), col_indices: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): value = values[0] for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): if matrix[row_j, col_i] != 0.: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value + out[row_k, col_i] += value * matrix[row_j, col_i] @ti.kernel -def _event_csr_matmat_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): if matrix[row_j, col_i]: for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value + out[row_k, col_i] += value * matrix[row_j, col_i] @ti.kernel -def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. 
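# Illustrative aside, not part of the patch: a dense reference for the
# homogeneous event kernels, assuming SciPy is available. With one stored
# weight `value`, the float-event kernel computes
#   out[i, k] = value * sum_{j in row i} M[col[j], k]   (zero events add 0),
# which is exactly value * (pattern @ M):
import numpy as np
import scipy.sparse as sp

A = sp.random(50, 40, density=0.05, format='csr', dtype=np.float64)
A.data[:] = 2.0                      # homogeneous weight
events = np.random.rand(40, 6)
events[events < 0.5] = 0.            # float "events": zeros fire nothing
pattern = A.copy()
pattern.data[:] = 1.0                # connectivity only
assert np.allclose(A @ events, 2.0 * (pattern @ events))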
@@ -367,11 +215,11 @@ def _event_csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), @ti.kernel -def _event_csr_matmat_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. @@ -421,33 +269,33 @@ def _define_op(cpu_kernel, gpu_kernel): # transpose heter -_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter_cpu, - gpu_kernel=_event_csr_matmat_transpose_heter_gpu) +_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter, + gpu_kernel=_event_csr_matmat_transpose_heter) # no transpose heter -_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter_cpu, - gpu_kernel=_event_csr_matmat_heter_gpu) +_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter, + gpu_kernel=_event_csr_matmat_heter) # transpose homo -_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo_cpu, - gpu_kernel=_event_csr_matmat_transpose_homo_gpu) +_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo, + gpu_kernel=_event_csr_matmat_transpose_homo) # no transpose homo -_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo_cpu, - gpu_kernel=_event_csr_matmat_homo_gpu) +_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo, + gpu_kernel=_event_csr_matmat_homo) # bool transpose heter -_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter_cpu, - gpu_kernel=_event_csr_matmat_transpose_bool_heter_gpu) +_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter, + gpu_kernel=_event_csr_matmat_transpose_bool_heter) # bool no transpose heter -_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter_cpu, - gpu_kernel=_event_csr_matmat_bool_heter_gpu) +_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter, + gpu_kernel=_event_csr_matmat_bool_heter) # bool transpose homo -_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo_cpu, - gpu_kernel=_event_csr_matmat_transpose_bool_homo_gpu) +_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo, + gpu_kernel=_event_csr_matmat_transpose_bool_homo) # bool no transpose homo -_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo_cpu, - gpu_kernel=_event_csr_matmat_bool_homo_gpu) +_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, + gpu_kernel=_event_csr_matmat_bool_homo) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 6fa1aedc4..33aa803df 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -5,11 +5,9 @@ import jax import numpy as np -import brainpy.math as bm from jax import numpy as jnp -from jax.interpreters import ad -from jax.core import Tracer from jax.experimental.sparse import csr +from jax.interpreters import ad from brainpy._src.dependency_check 
import import_taichi from brainpy._src.math.interoperability import as_jax @@ -95,15 +93,9 @@ def raw_csrmm_taichi( if indices.shape[0] == 0: return [jnp.zeros(result_shape, dtype=data.dtype), ] # homo -> taichi, - # heter(CPU) -> taichi, heter(GPU) -> cusparse + # heter -> cusparse if data.shape[0] != 1: - if bm.get_platform() == 'gpu': - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - if transpose: - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - prim = _csr_matmat_heter_p + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: if transpose: prim = _csr_matmat_transpose_homo_p @@ -118,94 +110,27 @@ def raw_csrmm_taichi( shape=shape) -# CPU kernels - -@ti.kernel -def _csr_matmat_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - val = 0. - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r - - -@ti.kernel -def _csr_matmat_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value - - -@ti.kernel -def _csr_matmat_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -# GPU kernels +# taichi kernels @ti.kernel -def _csr_matmat_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): - val = 0. 
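# In the merged transpose kernels each ndrange thread owns a unique
# (col_i, row_k) output element, so accumulating directly with
# `out[row_k, col_i] += ...` is race-free and the dead r/val temporaries
# can be dropped.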
for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - val = values[j] - r += val * matrix[row_j, col_i] - out[row_k, col_i] = r + out[row_k, col_i] += values[j] * matrix[row_j, col_i] @ti.kernel -def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. for j in range(row_ptr[row_i], row_ptr[row_i + 1]): @@ -214,28 +139,25 @@ def _csr_matmat_heter_gpu(values: ti.types.ndarray(ndim=1), @ti.kernel -def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - r = 0. for row_j in range(matrix.shape[0]): for j in range(row_ptr[row_j], row_ptr[row_j + 1]): if col_indices[j] == row_k: - r += matrix[row_j, col_i] - break - out[row_k, col_i] = r * value + out[row_k, col_i] += value * matrix[row_j, col_i] @ti.kernel -def _csr_matmat_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): r = 0. 
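# Illustrative aside, not part of the patch: a tiny end-to-end check of the
# homogeneous path, assuming the 2-tuple `shape` convention used in the tests.
import jax.numpy as jnp
import brainpy.math as bm

indptr = jnp.array([0, 2, 3])       # 2x3 CSR: row 0 -> cols {0, 2}, row 1 -> {1}
indices = jnp.array([0, 2, 1])
data = jnp.array([2.0])             # homogeneous value
M = jnp.arange(6.).reshape(3, 2)
dense = jnp.zeros((2, 3)).at[jnp.array([0, 0, 1]), indices].set(2.0)
out = bm.sparse.csrmm(data, indices, indptr, M, shape=(2, 3))
assert jnp.allclose(out, dense @ M)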
@@ -283,20 +205,20 @@ def _define_op(cpu_kernel, gpu_kernel): # transpose heter -_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter_cpu, - gpu_kernel=_csr_matmat_transpose_heter_gpu) +_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, + gpu_kernel=_csr_matmat_transpose_heter) # no transpose heter -_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter_cpu, - gpu_kernel=_csr_matmat_heter_gpu) +_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, + gpu_kernel=_csr_matmat_heter) # transpose homo -_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) +_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo, + gpu_kernel=_csr_matmat_transpose_homo) # no transpose homo -_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo_cpu, - gpu_kernel=_csr_matmat_homo_gpu) +_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, + gpu_kernel=_csr_matmat_homo) # heter CUSPARSE _csr_matmat_cusparse_p = csr.csr_matmat_p From 9f3b9b24decef5d19a132c85031127ee97e66ff8 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 7 Mar 2024 21:46:18 +0800 Subject: [PATCH 17/23] Update csrmm --- brainpy/_src/math/sparse/csr_mm.py | 50 ++-- .../csr_matmat_VS_cusparse_csr_matmat.py | 245 ++++++++++-------- 2 files changed, 171 insertions(+), 124 deletions(-) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 33aa803df..dfea2a6b0 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -92,6 +92,7 @@ def raw_csrmm_taichi( if indices.shape[0] == 0: return [jnp.zeros(result_shape, dtype=data.dtype), ] + # homo -> taichi, # heter -> cusparse if data.shape[0] != 1: @@ -118,11 +119,11 @@ def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), row_ptr: ti.types.ndarray(ndim=1), matrix: ti.types.ndarray(ndim=2), out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += values[row_i] * matrix[row_i, j] @ti.kernel @@ -139,17 +140,32 @@ def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), @ti.kernel -def _csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): +def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += value * matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], 
out.shape[0]): - for row_j in range(matrix.shape[0]): - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] + for row_i in range(row_ptr.shape[0] - 1): + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + col = col_indices[i] + for j in range(out.shape[1]): + out[col, j] += value * matrix[row_i, j] @ti.kernel @@ -213,8 +229,8 @@ def _define_op(cpu_kernel, gpu_kernel): gpu_kernel=_csr_matmat_heter) # transpose homo -_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo, - gpu_kernel=_csr_matmat_transpose_homo) +_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) # no transpose homo _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py index 79c8bef0a..f11275b17 100644 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -14,163 +14,180 @@ bm.set_platform('cpu') -s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] -p = [0.1, 0.2, 0.3, 0.4, 0.5] +SPARSITY = 0.05 size = [ - (100, 100, 100), - (100, 1000, 100), - (1000, 1000, 100), - (1000, 1000, 1000), - (100, 10000, 100), - (10000, 100, 1000), - (1000, 100, 10000), - (10000, 10000, 1000), - (10000, 1000, 10000), - (10000, 10000, 10000), - (20000, 20000, 20000), + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), ] values_type = [ - 'heter' - ] + 'homo', + # 'heter' +] events_type = ['float'] transpose = [ - True, - # False - ] + True, + False +] -ITERATION = 100 +ITERATION = 10 if bm.get_platform() == 'cpu': - ITERATION = 10 + ITERATION = 3 print(bm.get_platform()) + @partial(jax.jit, static_argnums=(4, 5)) def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None) return r - + + @partial(jax.jit, static_argnums=(4, 5)) def csrmm(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): - r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='cusparse') + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='jaxlib') return r + def test_sparse_csrmm(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) matrix2_shape = (shape[1], shape[2]) - indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') matrix = rng.random(matrix2_shape) matrix = bm.as_jax(matrix) weight = 1. 
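# Aside, not part of the patch: the ten hand-unrolled warm-up/timing blocks
# below could collapse into a helper like this sketch (the name `bench` is
# hypothetical; `jax` and `time` are already imported in this file):
def bench(fn, *args, n_warmup=5, n_runs=10, **kwargs):
    times = []
    for _ in range(n_warmup):
        jax.block_until_ready(fn(*args, **kwargs))
    for _ in range(n_runs):
        t0 = time.time()
        jax.block_until_ready(fn(*args, **kwargs))
        times.append((time.time() - t0) * 1000.)  # ms
    return times
# Also note: the result rows written below record p=0.5, while the
# connectivity actually drawn above uses SPARSITY = 0.05.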
- - + + heter_data = bm.ones(indices.shape) * weight if events_type == 'float': matrix = matrix.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data + # if values_type == 'heter': + # weight = heter_data - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time0 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time1 = time.time() time2 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time3 = time.time() time4 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time5 = time.time() time6 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time9 = time.time() - + time10 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time11 = time.time() - + time12 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time13 = time.time() - + time14 = time.time() - 
result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time15 = time.time() - + time16 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time17 = time.time() - + time18 = time.time() - result = jax.block_until_ready(csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready( + csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time19 = time.time() - + result1 = result - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time20 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time21 = time.time() - + result2 = result - + time22 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time23 = time.time() time24 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time25 = time.time() time26 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time27 = time.time() time28 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time29 = time.time() - + time30 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, 
shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time31 = time.time() - + time32 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time33 = time.time() - + time34 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time35 = time.time() - + time36 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time37 = time.time() - + time38 = time.time() - result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 @@ -206,53 +223,67 @@ def test_sparse_csrmm(shape, values_type, events_type, transpose): print('brainpylib_9: ', brainpy_time9, 'ms') print(bm.allclose(result1, result2)) - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe -df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) +df = pd.DataFrame( + columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', + 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', + 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): for shape in size: for 
_values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, + _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): for shape in size: for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, _values_type, _events_type, _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5 , shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ + brainpy_time_1, 
brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) From c298f48bc7951bd9c49d03c6b9915eb0b52c5233 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 8 Mar 2024 15:02:50 +0800 Subject: [PATCH 18/23] accelerate csrmm homo --- brainpy/_src/math/sparse/csr_mm.py | 258 ++++++++++++++--------------- 1 file changed, 122 insertions(+), 136 deletions(-) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index dfea2a6b0..5f4e07d44 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -12,10 +12,10 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy._src.math.sparse.utils import csr_to_coo +from brainpy._src.math.op_register import (XLACustomOp) +from brainpy.errors import PackageMissingError -ti = import_taichi() +ti = import_taichi(error_if_not_found=False) __all__ = [ 'csrmm', @@ -98,144 +98,130 @@ def raw_csrmm_taichi( if data.shape[0] != 1: return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] else: + if ti is None: + raise PackageMissingError.by_purpose('taichi', 'customzied sparse matrix multiplication') if transpose: prim = _csr_matmat_transpose_homo_p else: prim = _csr_matmat_homo_p - return prim(data, - indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], - transpose=transpose, - shape=shape) + r = prim(indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], + transpose=transpose, + shape=shape) + return [r[0] * data] # taichi kernels - -@ti.kernel -def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i in range(row_ptr.shape[0] - 1): - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - col = col_indices[i] - for j in range(out.shape[1]): - out[col, j] += values[row_i] * matrix[row_i, j] - - -@ti.kernel -def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. 
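# Note on the fast path introduced above: with a single stored value v,
# A(v) @ M == v * (A(1) @ M), so the primitive now takes only the sparse
# structure plus the dense matrix, and the scalar is applied afterwards as
# `r[0] * data`. Dropping `values` from the kernel signatures also means only
# the matrix JVP and transpose rules below are needed.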
- for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * matrix[col_indices[j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _csr_matmat_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - col = col_indices[i] - for j in range(out.shape[1]): - out[col, j] += value * matrix[row_i, j] - - -@ti.kernel -def _csr_matmat_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i in range(row_ptr.shape[0] - 1): - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - col = col_indices[i] - for j in range(out.shape[1]): - out[col, j] += value * matrix[row_i, j] - - -@ti.kernel -def _csr_matmat_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -def _csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_csrmm_taichi(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) - - -def _csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return raw_csrmm_taichi(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) - - -def _csr_matmat_transpose( - ct, data, indices, indptr, matrix, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(matrix): - ct_matrix = raw_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) +if ti is not None: + + # @ti.kernel + # def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), + # col_indices: ti.types.ndarray(ndim=1), + # row_ptr: ti.types.ndarray(ndim=1), + # matrix: ti.types.ndarray(ndim=2), + # out: ti.types.ndarray(ndim=2)): + # for row_i in range(row_ptr.shape[0] - 1): + # for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + # col = col_indices[i] + # for j in range(out.shape[1]): + # out[col, j] += values[row_i] * matrix[row_i, j] + # + # + # @ti.kernel + # def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), + # col_indices: ti.types.ndarray(ndim=1), + # row_ptr: ti.types.ndarray(ndim=1), + # matrix: ti.types.ndarray(ndim=2), + # out: ti.types.ndarray(ndim=2)): + # for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + # r = 0. 
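# (Reference only: heterogeneous matmat is now routed to cuSPARSE, hence these
# kernels stay commented out. If revived, note that the transpose variant
# above scales by values[row_i]; the per-nonzero weight values[i] appears to
# be what was intended.)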
+ # for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + # r += values[j] * matrix[col_indices[j], col_k] + # out[row_i, col_k] = r + + @ti.kernel + def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + for j in range(out.shape[1]): # parallize along the n dimension + for row_i in range(row_ptr.shape[0] - 1): # loop along the m dimension + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + + @ti.kernel + def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + m = row_ptr.shape[0] - 1 + n = matrix.shape[1] + for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + + @ti.kernel + def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + m, n = out.shape + for row_i, col_k in ti.ndrange(m, n): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + + def _csr_matmat_jvp_matrix(mat_dot, col_indices, row_ptr, matrix, *, outs, transpose, shape): + if transpose: + return _csr_matmat_transpose_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.sum(ct[0] * ct_data) - else: # heter - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - ct_data = (ct[0][row] * matrix[col]).sum(1) - return ct_data, indices, indptr, matrix - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_csr_matmat_jvp_values, None, None, _csr_matmat_jvp_matrix) - prim.def_transpose_rule(_csr_matmat_transpose) - return prim - - -# transpose heter -_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, - gpu_kernel=_csr_matmat_transpose_heter) - -# no transpose heter -_csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, - gpu_kernel=_csr_matmat_heter) - -# transpose homo -_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) - -# no transpose homo -_csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, - gpu_kernel=_csr_matmat_homo) - -# heter CUSPARSE -_csr_matmat_cusparse_p = csr.csr_matmat_p -register_general_batching(_csr_matmat_cusparse_p) + return _csr_matmat_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) + + + def _csr_matmat_transpose( + ct, col_indices, row_ptr, matrix, *, outs, transpose, shape, + ): + if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + assert ad.is_undefined_primal(matrix) + ct_matrix = _csr_matmat_transpose_homo_p(col_indices, row_ptr, ct[0], + shape=shape, + transpose=not transpose, + outs=[jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)]) + return col_indices, row_ptr, 
(ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0]) + + + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(None, None, _csr_matmat_jvp_matrix) + prim.def_transpose_rule(_csr_matmat_transpose) + return prim + + + # # transpose heter + # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, + # gpu_kernel=_csr_matmat_transpose_heter) + # + # # no transpose heter + # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, + # gpu_kernel=_csr_matmat_heter) + + # transpose homo + _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + + # no transpose homo + _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) + + # heter CUSPARSE + _csr_matmat_cusparse_p = csr.csr_matmat_p From f355d18f87b5a6941d265417b19929cc18062025 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 8 Mar 2024 15:12:15 +0800 Subject: [PATCH 19/23] upgrade --- brainpy/_src/math/sparse/csr_mm.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 5f4e07d44..d0ea66ca1 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -12,7 +12,7 @@ from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp) +from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) @@ -115,7 +115,6 @@ def raw_csrmm_taichi( # taichi kernels if ti is not None: - # @ti.kernel # def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), # col_indices: ti.types.ndarray(ndim=1), @@ -128,7 +127,6 @@ def raw_csrmm_taichi( # for j in range(out.shape[1]): # out[col, j] += values[row_i] * matrix[row_i, j] # - # # @ti.kernel # def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), # col_indices: ti.types.ndarray(ndim=1), @@ -140,6 +138,15 @@ def raw_csrmm_taichi( # for j in range(row_ptr[row_i], row_ptr[row_i + 1]): # r += values[j] * matrix[col_indices[j], col_k] # out[row_i, col_k] = r + # + # # transpose heter + # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, + # gpu_kernel=_csr_matmat_transpose_heter) + # + # # no transpose heter + # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, + # gpu_kernel=_csr_matmat_heter) + @ti.kernel def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), @@ -148,8 +155,10 @@ def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=2)): # matrix: (k, n) # sparse matrix: (m, k) - for j in range(out.shape[1]): # parallize along the n dimension - for row_i in range(row_ptr.shape[0] - 1): # loop along the m dimension + n = out.shape[1] + m = row_ptr.shape[0] - 1 + for j in range(n): # parallize along the n dimension + for row_i in range(m): # loop along the m dimension for i in range(row_ptr[row_i], row_ptr[row_i + 1]): out[col_indices[i], j] += matrix[row_i, j] @@ -208,14 +217,6 @@ def _define_op(cpu_kernel, gpu_kernel): return prim - # # transpose heter - # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, - # 
gpu_kernel=_csr_matmat_transpose_heter) - # - # # no transpose heter - # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, - # gpu_kernel=_csr_matmat_heter) - # transpose homo _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, gpu_kernel=_csr_matmat_transpose_homo_gpu) @@ -225,3 +226,4 @@ def _define_op(cpu_kernel, gpu_kernel): # heter CUSPARSE _csr_matmat_cusparse_p = csr.csr_matmat_p + register_general_batching(_csr_matmat_cusparse_p) From c92ace488e511061bb0900cde76e9f310e413abb Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Fri, 8 Mar 2024 17:13:50 +0800 Subject: [PATCH 20/23] Update csr_matmat_VS_cusparse_csr_matmat.py --- .../csr_matmat_VS_cusparse_csr_matmat.py | 104 ++++++++++++++++-- 1 file changed, 92 insertions(+), 12 deletions(-) diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py index f11275b17..79fe387ba 100644 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -1,20 +1,21 @@ -# from jax_taichi import jax_taichi_call - +import os import time from functools import partial -import os -import brainpy as bp -import brainpy.math as bm import jax import jax.numpy as jnp -import numpy as np import pandas as pd -import taichi as ti +from jax.experimental.sparse import csr -bm.set_platform('cpu') +import brainpy as bp +import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.op_register import XLACustomOp -SPARSITY = 0.05 +ti = import_taichi(error_if_not_found=False) + +bm.set_platform('cpu') size = [ (100, 100, 100), @@ -39,19 +40,98 @@ True, False ] - ITERATION = 10 +SPARSITY = 0.05 + if bm.get_platform() == 'cpu': ITERATION = 3 print(bm.get_platform()) +@ti.kernel +def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + n = out.shape[1] + m = row_ptr.shape[0] - 1 + for j in range(n): # parallize along the n dimension + for row_i in range(m): # loop along the m dimension + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + m = row_ptr.shape[0] - 1 + n = matrix.shape[1] + for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + m, n = out.shape + for row_i, col_k in ti.ndrange(m, n): + r = 0. 
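# Standalone copies of the library's homo kernels: the benchmark registers its
# own primitives below so the Taichi path and the jaxlib/cuSPARSE path can be
# timed head-to-head on identical inputs.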
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + +# transpose homo +_csr_matmat_transpose_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + +# no transpose homo +_csr_matmat_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) + + +def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + weight = jnp.atleast_1d(weight) + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + if transpose: + prim = _csr_matmat_transpose_homo_p + else: + prim = _csr_matmat_homo_p + r = prim(indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], + transpose=transpose, + shape=shape) + return r[0] * weight + + +def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + weight = jnp.atleast_1d(weight) + return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + + @partial(jax.jit, static_argnums=(4, 5)) def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): - r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method=None) + r += taichi_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) return r @@ -59,7 +139,7 @@ def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): def csrmm(weight, indices, indptr, matrix, shape, transpose): r = 0 for i in range(ITERATION): - r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose, method='jaxlib') + r += jaxlib_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) return r From 2589820c9ba8f352fc5e23b899a6662fbab57cb3 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 9 Mar 2024 21:12:31 +0800 Subject: [PATCH 21/23] update --- .../csr_matmat_VS_cusparse_csr_matmat.py | 110 ++++++++---------- 1 file changed, 48 insertions(+), 62 deletions(-) diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py index 79fe387ba..61c3f8c4f 100644 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -2,20 +2,20 @@ import time from functools import partial +import numpy as np + +os.environ["CUDA_VISIBLE_DEVICES"] = "2" + import jax import jax.numpy as jnp import pandas as pd +import taichi as ti from jax.experimental.sparse import csr import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import XLACustomOp - -ti = import_taichi(error_if_not_found=False) -bm.set_platform('cpu') +bm.set_platform('gpu') size = [ (100, 100, 100), @@ -37,7 +37,7 @@ ] events_type = ['float'] transpose = [ - True, + # True, False ] ITERATION = 10 @@ -92,17 +92,17 @@ def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), # transpose homo -_csr_matmat_transpose_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) +_csr_matmat_transpose_homo_p = 
bm.XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) # no transpose homo -_csr_matmat_homo_p = XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) +_csr_matmat_homo_p = bm.XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) weight = jnp.atleast_1d(weight) out_shape = shape[1] if transpose else shape[0] result_shape = (out_shape, matrix.shape[1]) @@ -120,9 +120,9 @@ def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) weight = jnp.atleast_1d(weight) return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose) @@ -321,49 +321,35 @@ def test_sparse_csrmm(shape, values_type, events_type, transpose): 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) -### RECTANGULAR MATRIX -if (bm.get_platform() == 'cpu'): - for shape in size: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, - _values_type, - _events_type, - _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, - _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, - taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, - taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) - -if (bm.get_platform() == 'gpu'): - for shape in size: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, - _values_type, - _events_type, - _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, - _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, - taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, - taichi_aot_time_10, - brainpy_time_1, 
brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) +for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + + print(shape, _values_type, _events_type, _transpose) + a = np.asarray([taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, + taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, + taichi_aot_time_10]) + b = np.asarray([brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]) + print(a) + print(b) + print(a.sum() / b.sum()) + df.to_csv(f'{PATH}/csrmm_{bm.get_platform()}.csv', index=False) From 1c284183deb0453ec4701ae4a83dcce271dd67d1 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sun, 10 Mar 2024 15:28:42 +0800 Subject: [PATCH 22/23] updates --- .../csr_matmat_VS_cusparse_csr_matmat.py | 554 +++++++++++------- 1 file changed, 332 insertions(+), 222 deletions(-) diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py index 61c3f8c4f..d40a93247 100644 --- a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -1,6 +1,5 @@ import os import time -from functools import partial import numpy as np @@ -8,7 +7,6 @@ import jax import jax.numpy as jnp -import pandas as pd import taichi as ti from jax.experimental.sparse import csr @@ -43,9 +41,6 @@ ITERATION = 10 SPARSITY = 0.05 -if bm.get_platform() == 'cpu': - ITERATION = 3 - print(bm.get_platform()) @@ -116,7 +111,97 @@ def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], transpose=transpose, shape=shape) - return r[0] * weight + return r[0] + + +SHARED_MEM_SIZE = 256 + + +# @ti.kernel +# def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), +# row_ptr: ti.types.ndarray(ndim=1), +# matrix: ti.types.ndarray(ndim=2), +# out: ti.types.ndarray(ndim=2)): +# m, n = out.shape +# l = col_indices.shape[0] +# ti.loop_config(block_dim=SHARED_MEM_SIZE) +# # for i_col, i_row in ti.ndrange(n, m): +# for i in range(m * n): +# indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) +# +# # one block threads compute will SHARED_MEM_SIZE columns +# i_row = i // 
SHARED_MEM_SIZE +# i_col = i % SHARED_MEM_SIZE +# +# index_start = row_ptr[i_row] +# end_border = row_ptr[i_row + 1] +# n_share = (end_border - index_start) // SHARED_MEM_SIZE +# n_last = end_border - index_start - n_share * SHARED_MEM_SIZE +# +# r = 0. +# for i_share in range(n_share): +# indices_sm[i_col] = col_indices[i_col + i_share * SHARED_MEM_SIZE] +# ti.simt.block.sync() +# # compute +# for j in range(SHARED_MEM_SIZE): +# r += matrix[indices_sm[j], i_col] +# indices_sm[i_col] = col_indices[ti.min(i_col + n_share * SHARED_MEM_SIZE, l)] +# ti.simt.block.sync() +# for j in range(n_last): +# r += matrix[indices_sm[j], i_col] +# out[i_row, i_col] += r + + +@ti.kernel +def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + m, n = out.shape + l = col_indices.shape[0] + ti.loop_config(block_dim=SHARED_MEM_SIZE) + + indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) + # for i_col, i_row in ti.ndrange(n, m): + for i in ti.ndrange(n * m): + # i_col = ti.global_thread_idx() % n + # i_row = ti.global_thread_idx() // n + i_col = i % n + i_row = i // n + i_share = i_col % SHARED_MEM_SIZE + + index_start = row_ptr[i_row] + end_border = row_ptr[i_row + 1] + n_share = (end_border - index_start) // SHARED_MEM_SIZE + n_last = end_border - index_start - n_share * SHARED_MEM_SIZE + + r = 0. + for k in range(n_share): + indices_sm[i_share] = col_indices[index_start + i_share + k * SHARED_MEM_SIZE] + ti.simt.block.sync() + for j in range(SHARED_MEM_SIZE): + r += matrix[indices_sm[j], i_col] + indices_sm[i_share] = col_indices[ti.min(index_start + i_share + n_share * SHARED_MEM_SIZE, l)] + ti.simt.block.sync() + for j in range(n_last): + r += matrix[indices_sm[j], i_col] + + # final results + out[i_row, i_col] += r + + +# no transpose homo +_csr_matmat_homo2_p = bm.XLACustomOp(gpu_kernel=_csr_matmat_homo2) + + +def taichi_csrmm2(weight, indices, indptr, matrix, shape, transpose): + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) + weight = jnp.atleast_1d(weight) + result_shape = (shape[1] if transpose else shape[0], matrix.shape[1]) + return _csr_matmat_homo2_p(indices, indptr, matrix, transpose=transpose, shape=shape, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)])[0] def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): @@ -127,229 +212,254 @@ def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose) -@partial(jax.jit, static_argnums=(4, 5)) -def csrmm_taichi(weight, indices, indptr, matrix, shape, transpose): - r = 0 - for i in range(ITERATION): - r += taichi_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - return r - +def generate_op(op): + def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + t = op(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + r += t + return r -@partial(jax.jit, static_argnums=(4, 5)) -def csrmm(weight, indices, indptr, matrix, shape, transpose): - r = 0 - for i in range(ITERATION): - r += jaxlib_csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) - return r + return jax.jit(csrmm, static_argnames=('shape', 'transpose')) -def test_sparse_csrmm(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) +def run_spmm_homo(op, shape, 
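+                  # op: a jitted benchmark wrapper built by `generate_op`;
+                  # shape: an (m, k, n) size triple taken from `size`, from which
+                  #   the CSR operand and the (k, n) dense operand are derived;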
transpose, use_heter_data=False): + bm.random.seed(1234) matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) matrix2_shape = (shape[1], shape[2]) indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') - matrix = rng.random(matrix2_shape) - matrix = bm.as_jax(matrix) + matrix = bm.as_jax(bm.random.random(matrix2_shape)) weight = 1. + if use_heter_data: + weight = bm.ones(indices.shape) * weight + + result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + times = [] + for i in range(10): + time0 = time.time() + result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + times.append(time1 - time0) + return np.asarray(times).mean(), result - heter_data = bm.ones(indices.shape) * weight - - if events_type == 'float': - matrix = matrix.astype(bm.float32) - # if values_type == 'heter': - # weight = heter_data - - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - - time0 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time1 = time.time() - - time2 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time3 = time.time() - - time4 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time5 = time.time() - - time6 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time9 = time.time() - - time10 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time11 = time.time() - - time12 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time13 = time.time() - - time14 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time15 = time.time() - - time16 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time17 = time.time() - - time18 = time.time() - result = jax.block_until_ready( - csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time19 = time.time() - - result1 = result - - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = 
jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - - time20 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time21 = time.time() - - result2 = result - - time22 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time23 = time.time() - - time24 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time25 = time.time() - - time26 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time27 = time.time() - - time28 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time29 = time.time() - - time30 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time31 = time.time() - - time32 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time33 = time.time() - - time34 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time35 = time.time() - - time36 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time37 = time.time() - - time38 = time.time() - result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) - time39 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - taichi_aot_time6 = (time11 - time10) * 1000 - taichi_aot_time7 = (time13 - time12) * 1000 - taichi_aot_time8 = (time15 - time14) * 1000 - taichi_aot_time9 = (time17 - time16) * 1000 - taichi_aot_time10 = (time19 - time18) * 1000 - brainpy_time1 = (time21 - time20) * 1000 - brainpy_time2 = (time23 - time22) * 1000 - brainpy_time3 = (time25 - time24) * 1000 - brainpy_time4 = (time27 - time26) * 1000 - brainpy_time5 = (time29 - time28) * 1000 - brainpy_time6 = (time31 - time30) * 1000 - brainpy_time7 = (time33 - time32) * 1000 - brainpy_time8 = (time35 - time34) * 1000 - brainpy_time9 = (time37 - time36) * 1000 - brainpy_time10 = (time39 - time38) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('taichi_aot_7: ', taichi_aot_time7, 'ms') - print('taichi_aot_9: ', taichi_aot_time9, 'ms') - print('brainpylib_1: ', brainpy_time1, 'ms') - print('brainpylib_3: ', brainpy_time3, 'ms') - print('brainpylib_5: ', brainpy_time5, 'ms') - 
print('brainpylib_7: ', brainpy_time7, 'ms') - print('brainpylib_9: ', brainpy_time9, 'ms') - print(bm.allclose(result1, result2)) - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ - taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ - brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 - - -PATH = os.path.dirname(os.path.abspath(__file__)) - -# init dataframe -df = pd.DataFrame( - columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', - 'taichi aot time5(ms)', - 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', - 'taichi aot time10(ms)', - 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) +bm.clear_taichi_aot_caches() for shape in size: - for _values_type in values_type: - for _events_type in events_type: - for _transpose in transpose: - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, - _values_type, - _events_type, - _transpose) - # append to dataframe - df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, - _transpose, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, - taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, - taichi_aot_time_10, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] - - print(shape, _values_type, _events_type, _transpose) - a = np.asarray([taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, - taichi_aot_time_5, - taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, - taichi_aot_time_10]) - b = np.asarray([brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, - brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]) - print(a) - print(b) - print(a.sum() / b.sum()) - df.to_csv(f'{PATH}/csrmm_{bm.get_platform()}.csv', index=False) + for _transpose in transpose: + cusparse_times, cusparse_r = run_spmm_homo(generate_op(jaxlib_csrmm), shape, _transpose, use_heter_data=True) + homo1_times, homo1_r = run_spmm_homo(generate_op(taichi_csrmm), shape, _transpose) + homo2_times, homo2_r = run_spmm_homo(generate_op(taichi_csrmm2), shape, _transpose) + print(jnp.allclose(cusparse_r, homo1_r), jnp.allclose(cusparse_r, homo2_r)) + print(f'shape={shape}, transpose={_transpose}, cusparse/homo1 = {cusparse_times / homo1_times}, ' + f'cusparse/homo2 = {cusparse_times / homo2_times}') + print(homo2_r) + +# def test_sparse_csrmm(shape, values_type, events_type, transpose): +# rng = bm.random.RandomState(seed=1234) +# matrix1_shape = (shape[1], shape[0]) if transpose else 
(shape[0], shape[1]) +# matrix2_shape = (shape[1], shape[2]) +# indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') +# matrix = rng.random(matrix2_shape) +# matrix = bm.as_jax(matrix) +# weight = 1. +# +# heter_data = bm.ones(indices.shape) * weight +# +# if events_type == 'float': +# matrix = matrix.astype(bm.float32) +# # if values_type == 'heter': +# # weight = heter_data +# +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# +# time0 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time1 = time.time() +# +# time2 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time3 = time.time() +# +# time4 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time5 = time.time() +# +# time6 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time7 = time.time() +# +# time8 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time9 = time.time() +# +# time10 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time11 = time.time() +# +# time12 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time13 = time.time() +# +# time14 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time15 = time.time() +# +# time16 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time17 = time.time() +# +# time18 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time19 = time.time() +# +# result1 = result +# +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# +# time20 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, 
indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time21 = time.time() +# +# result2 = result +# +# time22 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time23 = time.time() +# +# time24 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time25 = time.time() +# +# time26 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time27 = time.time() +# +# time28 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time29 = time.time() +# +# time30 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time31 = time.time() +# +# time32 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time33 = time.time() +# +# time34 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time35 = time.time() +# +# time36 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time37 = time.time() +# +# time38 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time39 = time.time() +# +# taichi_aot_time1 = (time1 - time0) * 1000 +# taichi_aot_time2 = (time3 - time2) * 1000 +# taichi_aot_time3 = (time5 - time4) * 1000 +# taichi_aot_time4 = (time7 - time6) * 1000 +# taichi_aot_time5 = (time9 - time8) * 1000 +# taichi_aot_time6 = (time11 - time10) * 1000 +# taichi_aot_time7 = (time13 - time12) * 1000 +# taichi_aot_time8 = (time15 - time14) * 1000 +# taichi_aot_time9 = (time17 - time16) * 1000 +# taichi_aot_time10 = (time19 - time18) * 1000 +# brainpy_time1 = (time21 - time20) * 1000 +# brainpy_time2 = (time23 - time22) * 1000 +# brainpy_time3 = (time25 - time24) * 1000 +# brainpy_time4 = (time27 - time26) * 1000 +# brainpy_time5 = (time29 - time28) * 1000 +# brainpy_time6 = (time31 - time30) * 1000 +# brainpy_time7 = (time33 - time32) * 1000 +# brainpy_time8 = (time35 - time34) * 1000 +# brainpy_time9 = (time37 - time36) * 1000 +# brainpy_time10 = (time39 - time38) * 1000 +# print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) +# print('taichi_aot_1: ', taichi_aot_time1, 'ms') +# print('taichi_aot_3: ', taichi_aot_time3, 'ms') +# print('taichi_aot_5: ', taichi_aot_time5, 'ms') +# print('taichi_aot_7: ', taichi_aot_time7, 'ms') +# print('taichi_aot_9: ', taichi_aot_time9, 'ms') +# print('brainpylib_1: ', brainpy_time1, 'ms') +# print('brainpylib_3: ', brainpy_time3, 'ms') +# print('brainpylib_5: ', brainpy_time5, 'ms') +# print('brainpylib_7: ', brainpy_time7, 'ms') +# print('brainpylib_9: ', brainpy_time9, 'ms') +# print(bm.allclose(result1, result2)) +# +# return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ +# taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ +# brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ +# brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, 
brainpy_time10 + +# PATH = os.path.dirname(os.path.abspath(__file__)) +# +# # init dataframe +# df = pd.DataFrame( +# columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', +# 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', +# 'taichi aot time5(ms)', +# 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', +# 'taichi aot time10(ms)', +# 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', +# 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) +# +# for shape in size: +# for _values_type in values_type: +# for _events_type in events_type: +# for _transpose in transpose: +# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \ +# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \ +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ +# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmm(shape, +# _values_type, +# _events_type, +# _transpose) +# # append to dataframe +# df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, +# _transpose, +# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, +# taichi_aot_time_5, +# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, +# taichi_aot_time_10, +# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, +# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] +# +# print(shape, _values_type, _events_type, _transpose) +# a = np.asarray([taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, +# taichi_aot_time_5, +# taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, +# taichi_aot_time_10]) +# b = np.asarray([brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, +# brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]) +# print(a) +# print(b) +# print(a.sum() / b.sum()) +# df.to_csv(f'{PATH}/csrmm_{bm.get_platform()}.csv', index=False) From 68e25f9164c9fca315bf2e277bfc88fcb490ebc5 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Sat, 25 May 2024 20:52:18 +0800 Subject: [PATCH 23/23] Fix bugs --- brainpy/_src/math/event/csr_matmat.py | 35 +++++++++++++++++---------- brainpy/_src/math/sparse/csr_mm.py | 8 +++--- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index a8f55afbb..33677691a 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -7,13 +7,15 @@ import numpy as np from jax import numpy as jnp from jax.interpreters import ad +from jax.experimental.sparse import csr from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp) +from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm from brainpy._src.math.sparse.utils import csr_to_coo +from brainpy._src.math.defaults import float_ ti = 
import_taichi()
@@ -86,23 +88,26 @@ def raw_event_csrmm_taichi(
     return [jnp.zeros(result_shape, dtype=data.dtype), ]
   assert matrix.shape[0] == (shape[0] if transpose else shape[1])
-  if transpose:
+
+  # dispatch: homogeneous weights (data.shape[0] == 1) -> taichi kernels;
+  # heterogeneous weights -> the cuSPARSE primitive registered below
+  if data.shape[0] != 1:
     if matrix.dtype == jnp.bool_:
-      if data.shape[0] == 1:
-        prim = _event_csr_matmat_transpose_bool_homo_p
+      # cuSPARSE has no boolean kernel, so cast boolean event matrices to float
+      matrix = matrix.astype(float_)
+    return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
+  else:
+    if transpose:
+      if matrix.dtype == jnp.bool_:
+        prim = _event_csr_matmat_transpose_homo_p
       else:
-        prim = _event_csr_matmat_transpose_bool_heter_p
+        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
     else:
-      return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
-  else:
-    if matrix.dtype == jnp.bool_:
-      if data.shape[0] == 1:
+      if matrix.dtype == jnp.bool_:
         prim = _event_csr_matmat_bool_homo_p
       else:
-        prim = _event_csr_matmat_bool_heter_p
-    else:
-      return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
-  return prim(data,
+        return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
+    return prim(data,
               indices,
               indptr,
               matrix,
@@ -299,3 +304,7 @@ def _define_op(cpu_kernel, gpu_kernel):
 # bool no transpose homo
 _event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo,
                                            gpu_kernel=_event_csr_matmat_bool_homo)
+
+# heter CUSPARSE
+_csr_matmat_cusparse_p = csr.csr_matmat_p
+register_general_batching(_csr_matmat_cusparse_p)
\ No newline at end of file
diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py
index d0ea66ca1..47c24fa4f 100644
--- a/brainpy/_src/math/sparse/csr_mm.py
+++ b/brainpy/_src/math/sparse/csr_mm.py
@@ -147,7 +147,6 @@ def raw_csrmm_taichi(
   #   _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter,
   #                                    gpu_kernel=_csr_matmat_heter)
 
-
   @ti.kernel
   def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1),
                                      row_ptr: ti.types.ndarray(ndim=1),
@@ -203,10 +202,9 @@ def _csr_matmat_transpose(
   if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr):
     raise ValueError("Cannot transpose with respect to sparse indices.")
   assert ad.is_undefined_primal(matrix)
-  ct_matrix = _csr_matmat_transpose_homo_p(col_indices, row_ptr, ct[0],
-                                           shape=shape,
-                                           transpose=not transpose,
-                                           outs=[jax.ShapeDtypeStruct(matrix.shape, matrix.dtype)])
+  ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0],
+                               shape=shape,
+                               transpose=not transpose)
   return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0])
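A minimal sketch of exercising the resulting operator against a dense NumPy baseline. The `bm.sparse.csrmm` call signature and the `FixedProb` connection builder follow the benchmark script above; the concrete sizes, the 0.05 sparsity, and the `np.add.at` reference construction are illustrative assumptions rather than part of the patches:

import numpy as np

import brainpy as bp
import brainpy.math as bm

m, k, n = 50, 100, 20
indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(m, k).require('pre2post')
matrix = bm.as_jax(bm.random.random((k, n)))
weight = 2.0  # one shared value, so data.shape[0] == 1 selects the homogeneous taichi path

out = bm.sparse.csrmm(weight, indices, indptr, matrix, shape=(m, k), transpose=False)

# Dense reference: scatter the shared weight into an (m, k) matrix. np.add.at
# accumulates duplicate connections the same way the kernels' inner loops do.
dense = np.zeros((m, k))
indices_np, indptr_np = np.asarray(indices), np.asarray(indptr)
for row in range(m):
  np.add.at(dense[row], indices_np[indptr_np[row]:indptr_np[row + 1]], weight)

print(np.allclose(np.asarray(out), dense @ np.asarray(matrix)))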