diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py
index 9ebad3e9..91b479b6 100644
--- a/brainpy/_src/math/event/__init__.py
+++ b/brainpy/_src/math/event/__init__.py
@@ -1,2 +1,3 @@
 from .csr_matvec import *
+from .csr_matmat import *
diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py
new file mode 100644
index 00000000..33677691
--- /dev/null
+++ b/brainpy/_src/math/event/csr_matmat.py
@@ -0,0 +1,310 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Union, Tuple
+
+import jax
+import numpy as np
+from jax import numpy as jnp
+from jax.interpreters import ad
+from jax.experimental.sparse import csr
+
+from brainpy._src.dependency_check import import_taichi
+from brainpy._src.math.interoperability import as_jax
+from brainpy._src.math.ndarray import Array
+from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
+from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm
+from brainpy._src.math.sparse.utils import csr_to_coo
+from brainpy._src.math.defaults import float_
+
+ti = import_taichi()
+
+__all__ = [
+  'csrmm',
+]
+
+
+def csrmm(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    matrix: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  """Product of a CSR sparse matrix and a dense event matrix.
+
+  Args:
+    data : array of shape ``(nse,)``, float.
+    indices : array of shape ``(nse,)``.
+    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+    matrix : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+      dtype ``data.dtype``.
+    shape : length-2 tuple representing the sparse matrix shape.
+    transpose : boolean specifying whether to transpose the sparse matrix
+      before computing.
+
+  Returns:
+    C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+      representing the matrix-matrix product.
+  """
+  return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]
+
+
+def raw_event_csrmm_taichi(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    matrix: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  assert len(shape) == 2
+
+  data = jnp.atleast_1d(data)
+  if np.ndim(data) == 1:
+    if data.shape[0] not in [1, indices.shape[0]]:
+      raise ValueError('The size of data should be 1 or be consistent with indices. '
+ f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + + indices = as_jax(indices) + indptr = as_jax(indptr) + matrix = as_jax(matrix) + + assert data.ndim == indices.ndim == indptr.ndim == 1 + assert matrix.ndim == 2 + assert indptr.shape[0] == shape[0] + 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + # if the shape of indices is (0,), then we return a zero matrix + if indices.shape[0] == 0: + return [jnp.zeros(result_shape, dtype=data.dtype), ] + + assert matrix.shape[0] == (shape[0] if transpose else shape[1]) + + # homo -> taichi + # heter -> cusparse + if data.shape[0] != 1: + if matrix.dtype == jnp.bool_: + # change dtype to float + matrix = matrix.astype(float_) + return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] + else: + if transpose: + if matrix.dtype == jnp.bool_: + prim = _event_csr_matmat_transpose_homo_p + else: + return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) + else: + if matrix.dtype == jnp.bool_: + prim = _event_csr_matmat_bool_homo_p + else: + return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) + return prim(data, + indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], + transpose=transpose, + shape=shape) + + +# taichi kernels + +@ti.kernel +def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + out[row_k, col_i] += values[j] * matrix[row_j, col_i] + + +@ti.kernel +def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + out[row_k, col_i] += values[j] * matrix[row_j, col_i] + + +@ti.kernel +def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + +@ti.kernel +def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
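+    # event-driven gating: a connection contributes only when the
+    # corresponding entry of the dense event matrix is active (True/non-zero),
+    # so silent entries are skipped instead of being multiplied by zero.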
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += values[row_j] * matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + +@ti.kernel +def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i] != 0.: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + out[row_k, col_i] += value * matrix[row_j, col_i] + + +@ti.kernel +def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): + for row_j in range(matrix.shape[0]): + if matrix[row_j, col_i]: + for j in range(row_ptr[row_j], row_ptr[row_j + 1]): + if col_indices[j] == row_k: + out[row_k, col_i] += value * matrix[row_j, col_i] + + +@ti.kernel +def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k] != 0.: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + + +@ti.kernel +def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + value = values[0] + for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): + r = 0. 
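+    # homogeneous weight: accumulate the unweighted sum of active entries and
+    # apply the shared scalar once per output element (`r * value` below),
+    # saving one multiplication per stored nonzero.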
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + if matrix[col_indices[row_j], col_k]: + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r * value + + +def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) + + +def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): + return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) + + +def _event_csr_matmat_transpose( + ct, data, indices, indptr, matrix, *, outs, transpose, shape, +): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(matrix): + ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] + return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) + + else: + if type(ct[0]) is ad.Zero: + ct_data = ad.Zero(data) + else: + if data.aval.shape[0] == 1: # scalar + ct_data = \ + raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] + ct_data = jnp.sum(ct[0] * ct_data) + else: # heter + matrix = jnp.asarray(matrix) + row, col = csr_to_coo(indices, indptr) + ct_data = (ct[0][row] * matrix[col]).sum(1) + return ct_data, indices, indptr, matrix + + +def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) + prim.def_transpose_rule(_event_csr_matmat_transpose) + return prim + + +# transpose heter +_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter, + gpu_kernel=_event_csr_matmat_transpose_heter) + +# no transpose heter +_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter, + gpu_kernel=_event_csr_matmat_heter) + +# transpose homo +_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo, + gpu_kernel=_event_csr_matmat_transpose_homo) + +# no transpose homo +_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo, + gpu_kernel=_event_csr_matmat_homo) + +# bool transpose heter +_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter, + gpu_kernel=_event_csr_matmat_transpose_bool_heter) + +# bool no transpose heter +_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter, + gpu_kernel=_event_csr_matmat_bool_heter) + +# bool transpose homo +_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo, + gpu_kernel=_event_csr_matmat_transpose_bool_homo) + +# bool no transpose homo +_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, + gpu_kernel=_event_csr_matmat_bool_homo) + +# heter CUSPARSE +_csr_matmat_cusparse_p = csr.csr_matmat_p +register_general_batching(_csr_matmat_cusparse_p) \ No newline at end of file diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index 9890838e..d4478345 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -131,10 +131,7 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_transpose_bool_heter_p else: - if data.shape[0] == 1: - prim = 
_event_csrmv_transpose_homo_p - else: - prim = _event_csrmv_transpose_heter_p + return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) else: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -142,10 +139,7 @@ def raw_csrmv_taichi( else: prim = _event_csrmv_bool_heter_p else: - if data.shape[0] == 1: - prim = _event_csrmv_homo_p - else: - prim = _event_csrmv_heter_p + return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) # computing return prim(data, diff --git a/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py new file mode 100644 index 00000000..872c69e1 --- /dev/null +++ b/brainpy/_src/math/event/tests/event_csr_matmat_VS_csr_matmat.py @@ -0,0 +1,285 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'heter', + 'homo', +] +events_type = ['bool', + 'float', + ] +transpose = [ + # True, + False +] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmm(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + return r + + +def test_sparse_csrmm(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = rng.random(matrix2_shape) + matrix = bm.as_jax(matrix) + weight = 1. 
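+  # Review note: the ten hand-unrolled timing blocks below could be collapsed
+  # into a small helper such as the (hypothetical) `bench` sketched here,
+  # mirroring `run_spmm_homo` in csr_matmat_VS_cusparse_csr_matmat.py. Also,
+  # since `shape` and `transpose` are passed to the jitted wrappers by
+  # keyword, `static_argnames=('shape', 'transpose')` (as `generate_op` does
+  # in that file) would make the static arguments explicit.
+  #
+  #   def bench(op, *args, n_rep=10, **kwargs):
+  #     jax.block_until_ready(op(*args, **kwargs))  # warm-up / compile once
+  #     times = []
+  #     for _ in range(n_rep):
+  #       t0 = time.time()
+  #       jax.block_until_ready(op(*args, **kwargs))
+  #       times.append((time.time() - t0) * 1000)  # milliseconds
+  #     return times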
+ + if events_type == 'float': + matrix = matrix.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time19 = time.time() + + result1 = result + + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time21 = time.time() + + result2 = result + + time22 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = 
jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(event_csrmm(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time39 = time.time() + + csrmm_time1 = (time1 - time0) * 1000 + csrmm_time2 = (time3 - time2) * 1000 + csrmm_time3 = (time5 - time4) * 1000 + csrmm_time4 = (time7 - time6) * 1000 + csrmm_time5 = (time9 - time8) * 1000 + csrmm_time6 = (time11 - time10) * 1000 + csrmm_time7 = (time13 - time12) * 1000 + csrmm_time8 = (time15 - time14) * 1000 + csrmm_time9 = (time17 - time16) * 1000 + csrmm_time10 = (time19 - time18) * 1000 + event_csrmm_time1 = (time21 - time20) * 1000 + event_csrmm_time2 = (time23 - time22) * 1000 + event_csrmm_time3 = (time25 - time24) * 1000 + event_csrmm_time4 = (time27 - time26) * 1000 + event_csrmm_time5 = (time29 - time28) * 1000 + event_csrmm_time6 = (time31 - time30) * 1000 + event_csrmm_time7 = (time33 - time32) * 1000 + event_csrmm_time8 = (time35 - time34) * 1000 + event_csrmm_time9 = (time37 - time36) * 1000 + event_csrmm_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('csrmm_1: ', csrmm_time1, 'ms') + print('csrmm_3: ', csrmm_time3, 'ms') + print('csrmm_5: ', csrmm_time5, 'ms') + print('csrmm_7: ', csrmm_time7, 'ms') + print('csrmm_9: ', csrmm_time9, 'ms') + print('event_csrmm_1: ', event_csrmm_time1, 'ms') + print('event_csrmm_3: ', event_csrmm_time3, 'ms') + print('event_csrmm_5: ', event_csrmm_time5, 'ms') + print('event_csrmm_7: ', event_csrmm_time7, 'ms') + print('event_csrmm_9: ', event_csrmm_time9, 'ms') + + r = bm.allclose(result1, result2) + if not r: + print('result1: ', result1) + print('result2: ', result2) + + return csrmm_time1, csrmm_time2, csrmm_time3, csrmm_time4, csrmm_time5, \ + csrmm_time6, csrmm_time7, csrmm_time8, csrmm_time9, csrmm_time10, \ + event_csrmm_time1, event_csrmm_time2, event_csrmm_time3, event_csrmm_time4, event_csrmm_time5, \ + event_csrmm_time6, event_csrmm_time7, event_csrmm_time8, event_csrmm_time9, event_csrmm_time10 + + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame( + columns=['shape', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', + 'csrmm time1(ms)', 'csrmm time2(ms)', 'csrmm time3(ms)', 'csrmm time4(ms)', + 'csrmm time5(ms)', + 'csrmm time6(ms)', 'csrmm time7(ms)', 'csrmm time8(ms)', 'csrmm time9(ms)', + 'csrmm time10(ms)', + 'event_csrmm time1(ms)', 'event_csrmm time2(ms)', 'event_csrmm time3(ms)', 'event_csrmm 
time4(ms)', + 'event_csrmm time5(ms)', + 'event_csrmm time6(ms)', 'event_csrmm time7(ms)', 'event_csrmm time8(ms)', 'event_csrmm time9(ms)', + 'event_csrmm time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'cpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape in size: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, csrmm_time_5, \ + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, csrmm_time_10, \ + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, event_csrmm_time_5, \ + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, event_csrmm_time_10 = test_sparse_csrmm( + shape, + _values_type, + _events_type, + _transpose) + # append to dataframe + df.loc[df.shape[0]] = [shape, 0.05, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type, + _transpose, + csrmm_time_1, csrmm_time_2, csrmm_time_3, csrmm_time_4, + csrmm_time_5, + csrmm_time_6, csrmm_time_7, csrmm_time_8, csrmm_time_9, + csrmm_time_10, + event_csrmm_time_1, event_csrmm_time_2, event_csrmm_time_3, event_csrmm_time_4, + event_csrmm_time_5, + event_csrmm_time_6, event_csrmm_time_7, event_csrmm_time_8, event_csrmm_time_9, + event_csrmm_time_10] + df.to_csv(f'{PATH}/csrmm_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/test_event_csrmm.py b/brainpy/_src/math/event/tests/test_event_csrmm.py new file mode 100644 index 00000000..12a35ef3 --- /dev/null +++ b/brainpy/_src/math/event/tests/test_event_csrmm.py @@ -0,0 +1,276 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +from absl.testing import parameterized + +import brainpy as bp +import brainpy.math as bm + +# bm.set_platform('gpu') + +seed = 1234 + + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +class Test_csrmm(parameterized.TestCase): + def __init__(self, *args, platform='cpu', **kwargs): + super(Test_csrmm, self).__init__(*args, **kwargs) + + print() + bm.set_platform(platform) + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
+ ) + def test_homo(self, transpose, shape, homo_data): + print(f'test_homo: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + heter_data = bm.ones(indices.shape) * homo_data + + dense = bm.sparse.csr_to_dense(heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])) + + r1 = (dense.T @ matrix) if transpose else (dense @ matrix) + r2 = bm.event.csrmm(homo_data, indices, indptr, matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose) + c = bm.allclose(r1, r2, equal_nan=True) + if not c: + print(r1 - r2) + self.assertTrue(c) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] + ) + def test_homo_vmap(self, transpose, shape, homo_data): + print(f'test_homo_vmap: transpose: {transpose} shape = {shape}') + conn = bp.conn.FixedProb(0.3) + + # csr matrix + indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0], + shape[1]).require( + 'pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # matrix + rng = bm.random.RandomState(seed=seed) + matrix = rng.random((shape[1], shape[2])) < 0.1 + matrix = bm.as_jax(matrix) + + # vmap 'data' + f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + vmap_data = bm.as_jax([homo_data] * 10) + heter_data = bm.ones((10, indices.shape[0])) * homo_data + r1 = f1(heter_data) + r2 = f2(vmap_data) + self.assertTrue(bm.allclose(r1, r2)) + + # vmap 'events' + heter_data = bm.ones(indices.shape) * homo_data + f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + f4 = jax.vmap(partial(bm.event.csrmm, homo_data, indices, indptr, + shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)) + matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1) + r3 = f3(matrix) + r4 = f4(matrix) + self.assertTrue(bm.allclose(r3, r4)) + + bm.clear_buffer_memory() + + @parameterized.product( + transpose=[True, False], + shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)], + homo_data=[-1., 1.] 
+  )
+  def test_homo_grad(self, transpose, shape, homo_data):
+    print(f'test_homo_grad: transpose: {transpose} shape = {shape}')
+    rng = bm.random.RandomState(seed=seed)
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0],
+                                                                                          shape[1]).require(
+      'pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
+                                   indices,
+                                   indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+
+    heter_data = bm.as_jax(rng.random((indices.shape)))
+    # matrix
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    # grad data
+    dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum()
+                                   if transpose else
+                                   ((dense * a) @ matrix).sum()),
+                        argnums=0)
+    r1 = dense_f1(homo_data)
+    r2 = jax.grad(sum_op(bm.event.csrmm))(
+      bm.asarray([homo_data]), indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+
+    self.assertTrue(bm.allclose(r1, r2))
+
+    # grad events matrix
+    dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum()
+                                   if transpose else
+                                   ((dense * homo_data) @ m).sum()),
+                        argnums=0)
+    r3 = dense_f2(matrix.astype(float))
+    r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)(
+      bm.asarray([homo_data]), indices, indptr, matrix.astype(float),
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
+
+    self.assertTrue(bm.allclose(r3, r4))
+
+    bm.clear_buffer_memory()
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter(self, transpose, shape):
+    print(f'test_heter: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0],
+                                                                                          shape[1]).require(
+      'pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    heter_data = bm.as_jax(rng.random(indices.shape))
+
+    r1 = bm.sparse.csrmm(heter_data, indices, indptr, matrix,
+                         shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
+    r2 = bm.event.csrmm(heter_data, indices, indptr, matrix,
+                        shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
+
+    self.assertTrue(bm.allclose(r1, r2))
+
+    bm.clear_buffer_memory()
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter_vmap(self, transpose, shape):
+    print(f'test_heter_vmap: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0],
+                                                                                          shape[1]).require(
+      'pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    # vmap 'data'
+    f1 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix,
+                          shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose))
+    f2 = jax.vmap(partial(bm.event.csrmm, indices=indices, indptr=indptr, matrix=matrix,
+                          shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose))
+    vmap_data = bm.as_jax(rng.random((10, indices.shape[0])))
+    r1 = f1(vmap_data)
+    r2 = f2(vmap_data)
+    self.assertTrue(bm.allclose(r1, r2))
+
+    # vmap 'events'
+    heter_data = bm.ones(indices.shape)
+    f3 = jax.vmap(partial(bm.sparse.csrmm, heter_data, indices, indptr,
+                          shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose))
+    f4 = jax.vmap(partial(bm.event.csrmm, heter_data, indices, indptr,
+                          shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose))
+    matrix = bm.as_jax(rng.random((10, shape[1], shape[2])) < 0.1)
+    r3 = f3(matrix)
+    r4 = f4(matrix)
+    self.assertTrue(bm.allclose(r3, r4))
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter_grad(self, transpose, shape):
+    print(f'test_heter_grad: transpose: {transpose} shape = {shape}')
+    rng = bm.random.RandomState(seed=seed)
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    indices, indptr = conn(shape[1], shape[0]).require('pre2post') if transpose else conn(shape[0],
+                                                                                          shape[1]).require(
+      'pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
+                                   indices,
+                                   indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+
+    heter_data = bm.as_jax(rng.random((indices.shape)))
+    # matrix
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    # grad data
+    r1 = jax.grad(sum_op(bm.sparse.csrmm))(
+      heter_data, indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+    r2 = jax.grad(sum_op(bm.event.csrmm))(
+      heter_data, indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+    self.assertTrue(bm.allclose(r1, r2))
+
+    # grad events matrix
+    r3 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)(
+      heter_data, indices, indptr, matrix.astype(float),
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
+    r4 = jax.grad(sum_op(bm.event.csrmm), argnums=3)(
+      heter_data, indices, indptr, matrix.astype(float),
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]), transpose=transpose)
+
+    self.assertTrue(bm.allclose(r3, r4))
+
+    bm.clear_buffer_memory()
diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py
index 14256cbc..13c9e1e2 100644
--- a/brainpy/_src/math/sparse/__init__.py
+++ b/brainpy/_src/math/sparse/__init__.py
@@ -1,6 +1,7 @@
 # from ._coo_mv import *
 # from ._bsr_mv import *
 from .csr_mv import *
+from .csr_mm import *
 from .utils import *
 from .bsr_mm import *
 from .jax_prim import *
diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py
new file mode 100644
index 00000000..47c24fa4
--- /dev/null
+++ b/brainpy/_src/math/sparse/csr_mm.py
@@ -0,0 +1,227 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Union, Tuple
+
+import jax
+import numpy as np
+from jax import numpy as jnp
+from jax.experimental.sparse import csr
+from jax.interpreters import ad
+
+from brainpy._src.dependency_check import import_taichi
+from brainpy._src.math.interoperability import as_jax
+from brainpy._src.math.ndarray import Array
+from brainpy._src.math.op_register import (XLACustomOp, register_general_batching)
+from brainpy.errors import PackageMissingError
+
+ti = import_taichi(error_if_not_found=False)
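+
+# Usage sketch (illustrative values, not part of the module):
+#
+#   data, indices, indptr = ...  # CSR components of an (m, k) sparse matrix A
+#   B = jnp.ones((k, n))
+#   C = csrmm(data, indices, indptr, B, shape=(m, k))  # C = A @ B, shape (m, n)
+#
+# With ``transpose=True`` the product is ``A.T @ B``: ``B`` must instead have
+# shape ``(m, n)``, and the output has shape ``(k, n)``.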
+
+__all__ = [
+  'csrmm',
+]
+
+
+def csrmm(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    matrix: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  """
+  Product of a CSR sparse matrix and a dense matrix.
+
+  Args:
+    data : array of shape ``(nse,)``.
+    indices : array of shape ``(nse,)``.
+    indptr : array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``.
+    matrix : array of shape ``(shape[0] if transpose else shape[1], cols)`` and
+      dtype ``data.dtype``.
+    shape : length-2 tuple representing the sparse matrix shape.
+    transpose : boolean specifying whether to transpose the sparse matrix
+      before computing.
+
+  Returns:
+    C : array of shape ``(shape[1] if transpose else shape[0], cols)``
+      representing the matrix-matrix product.
+  """
+  return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0]
+
+
+def raw_csrmm_taichi(
+    data: Union[float, jnp.ndarray, Array],
+    indices: Union[jnp.ndarray, Array],
+    indptr: Union[jnp.ndarray, Array],
+    matrix: Union[jnp.ndarray, Array],
+    *,
+    shape: Tuple[int, int],
+    transpose: bool = False,
+):
+  assert len(shape) == 2
+
+  indices = as_jax(indices)
+  indptr = as_jax(indptr)
+  matrix = as_jax(matrix)
+  data = jnp.atleast_1d(data)
+
+  if matrix.dtype == jnp.bool_:
+    matrix = as_jax(matrix, dtype=data.dtype)
+
+  if data.dtype != matrix.dtype:
+    raise TypeError('The types of data and matrix should be the same. '
+                    f'But we got {data.dtype} != {matrix.dtype}.')
+  assert matrix.ndim == 2
+
+  if np.ndim(data) == 1:
+    if data.shape[0] not in [1, indices.shape[0]]:
+      raise ValueError('The size of data should be 1 or be consistent with indices. '
+                       f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.')
+  assert indptr.shape[0] == shape[0] + 1
+  if not jnp.issubdtype(indices.dtype, jnp.integer):
+    raise ValueError('indices should be a 1D vector with integer type.')
+  if not jnp.issubdtype(indptr.dtype, jnp.integer):
+    raise ValueError('indptr should be a 1D vector with integer type.')
+
+  out_shape = shape[1] if transpose else shape[0]
+  result_shape = (out_shape, matrix.shape[1])
+
+  assert matrix.shape[0] == (shape[0] if transpose else shape[1])
+
+  if indices.shape[0] == 0:
+    return [jnp.zeros(result_shape, dtype=data.dtype), ]
+
+  # homo -> taichi,
+  # heter -> cusparse
+  if data.shape[0] != 1:
+    return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ]
+  else:
+    if ti is None:
+      raise PackageMissingError.by_purpose('taichi', 'customized sparse matrix multiplication')
+    if transpose:
+      prim = _csr_matmat_transpose_homo_p
+    else:
+      prim = _csr_matmat_homo_p
+    r = prim(indices,
+             indptr,
+             matrix,
+             outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)],
+             transpose=transpose,
+             shape=shape)
+    return [r[0] * data]
+
+
+# taichi kernels
+if ti is not None:
+  # @ti.kernel
+  # def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1),
+  #                                 col_indices: ti.types.ndarray(ndim=1),
+  #                                 row_ptr: ti.types.ndarray(ndim=1),
+  #                                 matrix: ti.types.ndarray(ndim=2),
+  #                                 out: ti.types.ndarray(ndim=2)):
+  #   for row_i in range(row_ptr.shape[0] - 1):
+  #     for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
+  #       col = col_indices[i]
+  #       for j in range(out.shape[1]):
+  #         out[col, j] += values[row_i] * matrix[row_i, j]
+  #
+  # @ti.kernel
+  # def _csr_matmat_heter(values: ti.types.ndarray(ndim=1),
+  #                       col_indices: ti.types.ndarray(ndim=1),
+  #                       row_ptr: ti.types.ndarray(ndim=1),
+  #                       matrix: ti.types.ndarray(ndim=2),
+  #                       out: ti.types.ndarray(ndim=2)):
+  #   for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]):
+  #     r = 0.
+  #     for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+  #       r += values[j] * matrix[col_indices[j], col_k]
+  #     out[row_i, col_k] = r
+  #
+  # # transpose heter
+  # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter,
+  #                                            gpu_kernel=_csr_matmat_transpose_heter)
+  #
+  # # no transpose heter
+  # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter,
+  #                                  gpu_kernel=_csr_matmat_heter)
+
+  @ti.kernel
+  def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1),
+                                     row_ptr: ti.types.ndarray(ndim=1),
+                                     matrix: ti.types.ndarray(ndim=2),
+                                     out: ti.types.ndarray(ndim=2)):
+    # matrix: (k, n)
+    # sparse matrix: (m, k)
+    n = out.shape[1]
+    m = row_ptr.shape[0] - 1
+    for j in range(n):  # parallelize along the n dimension
+      for row_i in range(m):  # loop along the m dimension
+        for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
+          out[col_indices[i], j] += matrix[row_i, j]
+
+
+  @ti.kernel
+  def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1),
+                                     row_ptr: ti.types.ndarray(ndim=1),
+                                     matrix: ti.types.ndarray(ndim=2),
+                                     out: ti.types.ndarray(ndim=2)):
+    m = row_ptr.shape[0] - 1
+    n = matrix.shape[1]
+    for j, row_i in ti.ndrange(n, m):  # parallelize along the (n and m) dimensions
+      for i in range(row_ptr[row_i], row_ptr[row_i + 1]):
+        out[col_indices[i], j] += matrix[row_i, j]
+
+
+  @ti.kernel
+  def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1),
+                       row_ptr: ti.types.ndarray(ndim=1),
+                       matrix: ti.types.ndarray(ndim=2),
+                       out: ti.types.ndarray(ndim=2)):
+    # matrix: (k, n)
+    # sparse matrix: (m, k)
+    m, n = out.shape
+    for row_i, col_k in ti.ndrange(m, n):
+      r = 0.
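+      # homogeneous values never enter the kernel: the loop accumulates the
+      # plain sum of gathered rows, and the caller scales the result by the
+      # scalar weight (`return [r[0] * data]` in `raw_csrmm_taichi`). Each
+      # output element is owned by one thread, so no atomic update is needed,
+      # unlike the scattered `+=` in the transpose kernels above.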
+ for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + + def _csr_matmat_jvp_matrix(mat_dot, col_indices, row_ptr, matrix, *, outs, transpose, shape): + if transpose: + return _csr_matmat_transpose_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) + else: + return _csr_matmat_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) + + + def _csr_matmat_transpose( + ct, col_indices, row_ptr, matrix, *, outs, transpose, shape, + ): + if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + assert ad.is_undefined_primal(matrix) + ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0], + shape=shape, + transpose=not transpose) + return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0]) + + + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(None, None, _csr_matmat_jvp_matrix) + prim.def_transpose_rule(_csr_matmat_transpose) + return prim + + + # transpose homo + _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + + # no transpose homo + _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) + + # heter CUSPARSE + _csr_matmat_cusparse_p = csr.csr_matmat_p + register_general_batching(_csr_matmat_cusparse_p) diff --git a/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py new file mode 100644 index 00000000..d40a9324 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/csr_matmat_VS_cusparse_csr_matmat.py @@ -0,0 +1,465 @@ +import os +import time + +import numpy as np + +os.environ["CUDA_VISIBLE_DEVICES"] = "2" + +import jax +import jax.numpy as jnp +import taichi as ti +from jax.experimental.sparse import csr + +import brainpy as bp +import brainpy.math as bm + +bm.set_platform('gpu') + +size = [ + (100, 100, 100), + (100, 1000, 100), + (1000, 1000, 100), + (1000, 1000, 1000), + (100, 10000, 100), + (10000, 100, 1000), + (1000, 100, 10000), + (10000, 10000, 1000), + (10000, 1000, 10000), + (10000, 10000, 10000), + (20000, 20000, 20000), +] + +values_type = [ + 'homo', + # 'heter' +] +events_type = ['float'] +transpose = [ + # True, + False +] +ITERATION = 10 +SPARSITY = 0.05 + +print(bm.get_platform()) + + +@ti.kernel +def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + n = out.shape[1] + m = row_ptr.shape[0] - 1 + for j in range(n): # parallize along the n dimension + for row_i in range(m): # loop along the m dimension + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + m = row_ptr.shape[0] - 1 + n = matrix.shape[1] + for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions + for i in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[i], j] += matrix[row_i, j] + + +@ti.kernel +def _csr_matmat_homo(col_indices: 
ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + # matrix: (k, n) + # sparse matrix: (m, k) + m, n = out.shape + for row_i, col_k in ti.ndrange(m, n): + r = 0. + for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += matrix[col_indices[row_j], col_k] + out[row_i, col_k] = r + + +# transpose homo +_csr_matmat_transpose_homo_p = bm.XLACustomOp(cpu_kernel=_csr_matmat_transpose_homo_cpu, + gpu_kernel=_csr_matmat_transpose_homo_gpu) + +# no transpose homo +_csr_matmat_homo_p = bm.XLACustomOp(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) + + +def taichi_csrmm(weight, indices, indptr, matrix, shape, transpose): + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) + weight = jnp.atleast_1d(weight) + out_shape = shape[1] if transpose else shape[0] + result_shape = (out_shape, matrix.shape[1]) + if transpose: + prim = _csr_matmat_transpose_homo_p + else: + prim = _csr_matmat_homo_p + r = prim(indices, + indptr, + matrix, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], + transpose=transpose, + shape=shape) + return r[0] + + +SHARED_MEM_SIZE = 256 + + +# @ti.kernel +# def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), +# row_ptr: ti.types.ndarray(ndim=1), +# matrix: ti.types.ndarray(ndim=2), +# out: ti.types.ndarray(ndim=2)): +# m, n = out.shape +# l = col_indices.shape[0] +# ti.loop_config(block_dim=SHARED_MEM_SIZE) +# # for i_col, i_row in ti.ndrange(n, m): +# for i in range(m * n): +# indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) +# +# # one block threads compute will SHARED_MEM_SIZE columns +# i_row = i // SHARED_MEM_SIZE +# i_col = i % SHARED_MEM_SIZE +# +# index_start = row_ptr[i_row] +# end_border = row_ptr[i_row + 1] +# n_share = (end_border - index_start) // SHARED_MEM_SIZE +# n_last = end_border - index_start - n_share * SHARED_MEM_SIZE +# +# r = 0. +# for i_share in range(n_share): +# indices_sm[i_col] = col_indices[i_col + i_share * SHARED_MEM_SIZE] +# ti.simt.block.sync() +# # compute +# for j in range(SHARED_MEM_SIZE): +# r += matrix[indices_sm[j], i_col] +# indices_sm[i_col] = col_indices[ti.min(i_col + n_share * SHARED_MEM_SIZE, l)] +# ti.simt.block.sync() +# for j in range(n_last): +# r += matrix[indices_sm[j], i_col] +# out[i_row, i_col] += r + + +@ti.kernel +def _csr_matmat_homo2(col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + matrix: ti.types.ndarray(ndim=2), + out: ti.types.ndarray(ndim=2)): + m, n = out.shape + l = col_indices.shape[0] + ti.loop_config(block_dim=SHARED_MEM_SIZE) + + indices_sm = ti.simt.block.SharedArray((SHARED_MEM_SIZE,), ti.int32) + # for i_col, i_row in ti.ndrange(n, m): + for i in ti.ndrange(n * m): + # i_col = ti.global_thread_idx() % n + # i_row = ti.global_thread_idx() // n + i_col = i % n + i_row = i // n + i_share = i_col % SHARED_MEM_SIZE + + index_start = row_ptr[i_row] + end_border = row_ptr[i_row + 1] + n_share = (end_border - index_start) // SHARED_MEM_SIZE + n_last = end_border - index_start - n_share * SHARED_MEM_SIZE + + r = 0. 
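+    # cooperative tiling: each thread of the block stages one column index of
+    # the current tile into `indices_sm` (block-shared memory), synchronizes,
+    # then scans the whole tile while accumulating its own output element.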
+ for k in range(n_share): + indices_sm[i_share] = col_indices[index_start + i_share + k * SHARED_MEM_SIZE] + ti.simt.block.sync() + for j in range(SHARED_MEM_SIZE): + r += matrix[indices_sm[j], i_col] + indices_sm[i_share] = col_indices[ti.min(index_start + i_share + n_share * SHARED_MEM_SIZE, l)] + ti.simt.block.sync() + for j in range(n_last): + r += matrix[indices_sm[j], i_col] + + # final results + out[i_row, i_col] += r + + +# no transpose homo +_csr_matmat_homo2_p = bm.XLACustomOp(gpu_kernel=_csr_matmat_homo2) + + +def taichi_csrmm2(weight, indices, indptr, matrix, shape, transpose): + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) + weight = jnp.atleast_1d(weight) + result_shape = (shape[1] if transpose else shape[0], matrix.shape[1]) + return _csr_matmat_homo2_p(indices, indptr, matrix, transpose=transpose, shape=shape, + outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)])[0] + + +def jaxlib_csrmm(weight, indices, indptr, matrix, shape, transpose): + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + matrix = bm.as_jax(matrix) + weight = jnp.atleast_1d(weight) + return csr.csr_matmat_p.bind(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + + +def generate_op(op): + def csrmm(weight, indices, indptr, matrix, shape, transpose): + r = 0 + for i in range(ITERATION): + t = op(weight, indices, indptr, matrix, shape=shape, transpose=transpose) + r += t + return r + + return jax.jit(csrmm, static_argnames=('shape', 'transpose')) + + +def run_spmm_homo(op, shape, transpose, use_heter_data=False): + bm.random.seed(1234) + matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) + matrix2_shape = (shape[1], shape[2]) + indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') + matrix = bm.as_jax(bm.random.random(matrix2_shape)) + weight = 1. + if use_heter_data: + weight = bm.ones(indices.shape) * weight + + result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + times = [] + for i in range(10): + time0 = time.time() + result = jax.block_until_ready(op(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) + time1 = time.time() + times.append(time1 - time0) + return np.asarray(times).mean(), result + + +bm.clear_taichi_aot_caches() +for shape in size: + for _transpose in transpose: + cusparse_times, cusparse_r = run_spmm_homo(generate_op(jaxlib_csrmm), shape, _transpose, use_heter_data=True) + homo1_times, homo1_r = run_spmm_homo(generate_op(taichi_csrmm), shape, _transpose) + homo2_times, homo2_r = run_spmm_homo(generate_op(taichi_csrmm2), shape, _transpose) + print(jnp.allclose(cusparse_r, homo1_r), jnp.allclose(cusparse_r, homo2_r)) + print(f'shape={shape}, transpose={_transpose}, cusparse/homo1 = {cusparse_times / homo1_times}, ' + f'cusparse/homo2 = {cusparse_times / homo2_times}') + print(homo2_r) + +# def test_sparse_csrmm(shape, values_type, events_type, transpose): +# rng = bm.random.RandomState(seed=1234) +# matrix1_shape = (shape[1], shape[0]) if transpose else (shape[0], shape[1]) +# matrix2_shape = (shape[1], shape[2]) +# indices, indptr = bp.conn.FixedProb(SPARSITY, seed=1234, allow_multi_conn=True)(*matrix1_shape).require('pre2post') +# matrix = rng.random(matrix2_shape) +# matrix = bm.as_jax(matrix) +# weight = 1. 
+# +# heter_data = bm.ones(indices.shape) * weight +# +# if events_type == 'float': +# matrix = matrix.astype(bm.float32) +# # if values_type == 'heter': +# # weight = heter_data +# +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# +# time0 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time1 = time.time() +# +# time2 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time3 = time.time() +# +# time4 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time5 = time.time() +# +# time6 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time7 = time.time() +# +# time8 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time9 = time.time() +# +# time10 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time11 = time.time() +# +# time12 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time13 = time.time() +# +# time14 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time15 = time.time() +# +# time16 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time17 = time.time() +# +# time18 = time.time() +# result = jax.block_until_ready( +# csrmm_taichi(weight, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time19 = time.time() +# +# result1 = result +# +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# +# time20 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time21 = time.time() +# +# result2 = result +# +# time22 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) 
+# time23 = time.time() +# +# time24 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time25 = time.time() +# +# time26 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time27 = time.time() +# +# time28 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time29 = time.time() +# +# time30 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time31 = time.time() +# +# time32 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time33 = time.time() +# +# time34 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time35 = time.time() +# +# time36 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time37 = time.time() +# +# time38 = time.time() +# result = jax.block_until_ready(csrmm(heter_data, indices, indptr, matrix, shape=matrix1_shape, transpose=transpose)) +# time39 = time.time() +# +# taichi_aot_time1 = (time1 - time0) * 1000 +# taichi_aot_time2 = (time3 - time2) * 1000 +# taichi_aot_time3 = (time5 - time4) * 1000 +# taichi_aot_time4 = (time7 - time6) * 1000 +# taichi_aot_time5 = (time9 - time8) * 1000 +# taichi_aot_time6 = (time11 - time10) * 1000 +# taichi_aot_time7 = (time13 - time12) * 1000 +# taichi_aot_time8 = (time15 - time14) * 1000 +# taichi_aot_time9 = (time17 - time16) * 1000 +# taichi_aot_time10 = (time19 - time18) * 1000 +# brainpy_time1 = (time21 - time20) * 1000 +# brainpy_time2 = (time23 - time22) * 1000 +# brainpy_time3 = (time25 - time24) * 1000 +# brainpy_time4 = (time27 - time26) * 1000 +# brainpy_time5 = (time29 - time28) * 1000 +# brainpy_time6 = (time31 - time30) * 1000 +# brainpy_time7 = (time33 - time32) * 1000 +# brainpy_time8 = (time35 - time34) * 1000 +# brainpy_time9 = (time37 - time36) * 1000 +# brainpy_time10 = (time39 - time38) * 1000 +# print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) +# print('taichi_aot_1: ', taichi_aot_time1, 'ms') +# print('taichi_aot_3: ', taichi_aot_time3, 'ms') +# print('taichi_aot_5: ', taichi_aot_time5, 'ms') +# print('taichi_aot_7: ', taichi_aot_time7, 'ms') +# print('taichi_aot_9: ', taichi_aot_time9, 'ms') +# print('brainpylib_1: ', brainpy_time1, 'ms') +# print('brainpylib_3: ', brainpy_time3, 'ms') +# print('brainpylib_5: ', brainpy_time5, 'ms') +# print('brainpylib_7: ', brainpy_time7, 'ms') +# print('brainpylib_9: ', brainpy_time9, 'ms') +# print(bm.allclose(result1, result2)) +# +# return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5, \ +# taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10, \ +# brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ +# brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +# PATH = os.path.dirname(os.path.abspath(__file__)) +# +# # init dataframe +# df = pd.DataFrame( +# columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose', +# 'taichi aot time1(ms)', 'taichi aot 
+# PATH = os.path.dirname(os.path.abspath(__file__))
+#
+# # init dataframe
+# df = pd.DataFrame(
+#   columns=['s', 'p', 'shape[0]', 'shape[1]', 'shape[2]', 'backend', 'values type', 'events type', 'transpose',
+#            'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)',
+#            'taichi aot time5(ms)', 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)',
+#            'taichi aot time9(ms)', 'taichi aot time10(ms)',
+#            'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)',
+#            'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)'])
+#
+# for shape in size:
+#   for _values_type in values_type:
+#     for _events_type in events_type:
+#       for _transpose in transpose:
+#         taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, \
+#         taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, \
+#         brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \
+#         brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = \
+#           test_sparse_csrmm(shape, _values_type, _events_type, _transpose)
+#         # append to dataframe
+#         df.loc[df.shape[0]] = [shape, 0.5, shape[0], shape[1], shape[2], 'gpu', _values_type, _events_type,
+#                                _transpose,
+#                                taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4,
+#                                taichi_aot_time_5, taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8,
+#                                taichi_aot_time_9, taichi_aot_time_10,
+#                                brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+#                                brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10]
+#
+#         print(shape, _values_type, _events_type, _transpose)
+#         a = np.asarray([taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4,
+#                         taichi_aot_time_5, taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8,
+#                         taichi_aot_time_9, taichi_aot_time_10])
+#         b = np.asarray([brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5,
+#                         brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10])
+#         print(a)
+#         print(b)
+#         print(a.sum() / b.sum())
+# df.to_csv(f'{PATH}/csrmm_{bm.get_platform()}.csv', index=False)
diff --git a/brainpy/_src/math/sparse/tests/test_csrmm.py b/brainpy/_src/math/sparse/tests/test_csrmm.py
new file mode 100644
index 00000000..e4346c84
--- /dev/null
+++ b/brainpy/_src/math/sparse/tests/test_csrmm.py
@@ -0,0 +1,279 @@
+# -*- coding: utf-8 -*-
+
+from functools import partial
+
+import jax
+from absl.testing import parameterized
+
+import brainpy as bp
+import brainpy.math as bm
+
+# bm.set_platform('gpu')
+
+seed = 1234
+
+
+def sum_op(op):
+  def func(*args, **kwargs):
+    r = op(*args, **kwargs)
+    return r.sum()
+
+  return func
+
+
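The `sum_op` wrapper above turns a matrix-valued operation into a scalar-valued one so that `jax.grad` can differentiate it. A minimal sketch of the idiom with a plain dense matmul (shapes are hypothetical, not part of the test file):

import jax
import jax.numpy as jnp

def sum_op(op):
  def func(*args, **kwargs):
    return op(*args, **kwargs).sum()
  return func

a = jnp.ones((3, 4))
b = jnp.ones((4, 2))
# gradient of sum(a @ b) with respect to a: g[i, j] == b[j].sum() (here 2.0 everywhere)
g = jax.grad(sum_op(jnp.matmul))(a, b)
print(g.shape)  # (3, 4)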
+class Test_csrmm(parameterized.TestCase):
+  def __init__(self, *args, platform='cpu', **kwargs):
+    super(Test_csrmm, self).__init__(*args, **kwargs)
+
+    print()
+    bm.set_platform(platform)
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+    homo_data=[-1., 1.]
+  )
+  def test_homo(self, transpose, shape, homo_data):
+    print(f'test_homo: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    heter_data = bm.ones(indices.shape) * homo_data
+
+    dense = bm.sparse.csr_to_dense(heter_data, indices, indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+
+    r1 = (dense.T @ matrix) if transpose else (dense @ matrix)
+    r2 = bm.sparse.csrmm(homo_data, indices, indptr, matrix,
+                         shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+                         transpose=transpose)
+    c = bm.allclose(r1, r2, equal_nan=True)
+    if not c:
+      print(r1 - r2)
+    self.assertTrue(c)
+
+    bm.clear_buffer_memory()
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+    homo_data=[-1., 1.]
+  )
+  def test_homo_vmap(self, transpose, shape, homo_data):
+    print(f'test_homo_vmap: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    heter_data = bm.ones((10, indices.shape[0])) * homo_data
+    dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(
+      a, indices, indptr,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data)
+
+    # vmap 'data'
+    f1 = jax.vmap(lambda a: (a.T @ matrix) if transpose else (a @ matrix))
+    f2 = jax.vmap(partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix,
+                          shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+                          transpose=transpose))
+    vmap_data = bm.as_jax([homo_data] * 10)
+
+    r1 = f1(dense)
+    r2 = f2(vmap_data)
+    self.assertTrue(bm.allclose(r1, r2))
+
+    bm.clear_buffer_memory()
+
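The vmap tests share a single sparsity pattern across the batch and map only over the value vector. The same idiom, stripped to a standalone sketch with a tiny hand-built CSR matrix (all values hypothetical, mirroring the tests above):

import jax
import jax.numpy as jnp
import brainpy.math as bm

# a fixed 2x3 sparsity pattern: nonzeros at (0, 0), (0, 2), (1, 1)
indices = jnp.asarray([0, 2, 1])
indptr = jnp.asarray([0, 2, 3])
matrix = jnp.ones((3, 4))

# batch of 5 value vectors -> one sparse matmul per batch element
batched_data = jnp.ones((5, 3))
f = jax.vmap(lambda d: bm.sparse.csrmm(d, indices, indptr, matrix, shape=(2, 3)))
out = f(batched_data)
print(out.shape)  # (5, 2, 4)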
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+    homo_data=[-1., 1.]
+  )
+  def test_homo_grad(self, transpose, shape, homo_data):
+    print(f'test_homo_grad: transpose: {transpose} shape = {shape}')
+    rng = bm.random.RandomState(seed=seed)
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    dense = bm.sparse.csr_to_dense(bm.ones(indices.shape).value,
+                                   indices,
+                                   indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+
+    # matrix
+    matrix = rng.random((shape[1], shape[2])) < 0.1
+    matrix = bm.as_jax(matrix)
+
+    # grad data
+    dense_f1 = jax.grad(lambda a: (((dense.T * a) @ matrix).sum()
+                                   if transpose else
+                                   ((dense * a) @ matrix).sum()),
+                        argnums=0)
+    r1 = dense_f1(homo_data)
+    r2 = jax.grad(sum_op(bm.sparse.csrmm))(
+      bm.asarray([homo_data]), indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+
+    self.assertTrue(bm.allclose(r1, r2))
+
+    # grad matrix
+    dense_f2 = jax.grad(lambda m: (((dense.T * homo_data) @ m).sum()
+                                   if transpose else
+                                   ((dense * homo_data) @ m).sum()),
+                        argnums=0)
+    r3 = dense_f2(matrix.astype(float))
+    r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)(
+      bm.asarray([homo_data]), indices, indptr, matrix.astype(float),
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+
+    self.assertTrue(bm.allclose(r3, r4))
+
+    bm.clear_buffer_memory()
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter(self, transpose, shape):
+    print(f'test_heter: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2]))
+    matrix = bm.as_jax(matrix)
+
+    heter_data = bm.as_jax(rng.random(indices.shape))
+
+    dense = bm.sparse.csr_to_dense(heter_data, indices, indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+
+    r1 = (dense.T @ matrix) if transpose else (dense @ matrix)
+    r2 = bm.sparse.csrmm(heter_data, indices, indptr, matrix,
+                         shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+                         transpose=transpose)
+    c = bm.allclose(r1, r2, equal_nan=True)
+    if not c:
+      print(r1 - r2)
+    self.assertTrue(c)
+
+    bm.clear_buffer_memory()
+
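For reference, the shape bookkeeping these tests keep repeating follows the `csrmm` docstring: the dense operand has `shape[0]` rows when `transpose=True` and `shape[1]` rows otherwise. A minimal sketch checking both orientations of a 2x3 CSR matrix against the dense product (values hypothetical):

import jax.numpy as jnp
import brainpy.math as bm

# A is 2x3 in CSR form: A[0, 0] = 1., A[0, 2] = 2., A[1, 1] = 3.
data = jnp.asarray([1., 2., 3.])
indices = jnp.asarray([0, 2, 1])
indptr = jnp.asarray([0, 2, 3])
A = bm.sparse.csr_to_dense(data, indices, indptr, shape=(2, 3))

B1 = jnp.ones((3, 4))
out1 = bm.sparse.csrmm(data, indices, indptr, B1, shape=(2, 3))  # A @ B1, shape (2, 4)
print(bm.allclose(out1, A @ B1))

B2 = jnp.ones((2, 4))
out2 = bm.sparse.csrmm(data, indices, indptr, B2, shape=(2, 3), transpose=True)  # A.T @ B2, shape (3, 4)
print(bm.allclose(out2, A.T @ B2))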
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter_vmap(self, transpose, shape):
+    print(f'test_heter_vmap: transpose: {transpose} shape = {shape}')
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+    # matrix
+    rng = bm.random.RandomState(seed=seed)
+    matrix = rng.random((shape[1], shape[2]))
+    matrix = bm.as_jax(matrix)
+
+    heter_data = bm.as_jax(rng.random((10, indices.shape[0])))
+    dense = jax.vmap(lambda a: bm.sparse.csr_to_dense(
+      a, indices, indptr,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1])))(heter_data)
+
+    f1 = lambda a: (a.T @ matrix) if transpose else (a @ matrix)
+    f2 = partial(bm.sparse.csrmm, indices=indices, indptr=indptr, matrix=matrix,
+                 shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+                 transpose=transpose)
+    r1 = jax.vmap(f1)(dense)
+    r2 = jax.vmap(f2)(heter_data)
+
+    self.assertTrue(bm.allclose(r1, r2, equal_nan=True))
+
+    bm.clear_buffer_memory()
+
+  @parameterized.product(
+    transpose=[True, False],
+    shape=[(50, 50, 50), (100, 50, 100), (10, 1000, 10), (2, 2000, 2)],
+  )
+  def test_heter_grad(self, transpose, shape):
+    print(f'test_heter_grad: transpose: {transpose} shape = {shape}')
+    rng = bm.random.RandomState(seed=seed)
+    conn = bp.conn.FixedProb(0.3)
+
+    # csr matrix
+    if transpose:
+      indices, indptr = conn(shape[1], shape[0]).require('pre2post')
+    else:
+      indices, indptr = conn(shape[0], shape[1]).require('pre2post')
+    indices = bm.as_jax(indices)
+    indptr = bm.as_jax(indptr)
+
+    heter_data = bm.as_jax(rng.random(indices.shape))
+    dense = bm.sparse.csr_to_dense(heter_data,
+                                   indices,
+                                   indptr,
+                                   shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]))
+    # matrix
+    matrix = rng.random((shape[1], shape[2]))
+    matrix = bm.as_jax(matrix)
+
+    # grad data
+    dense_f1 = jax.grad(lambda a: ((a.T @ matrix).sum()
+                                   if transpose else
+                                   (a @ matrix).sum()),
+                        argnums=0)
+    r1 = dense_f1(dense)
+    r2 = jax.grad(sum_op(bm.sparse.csrmm))(
+      heter_data, indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+    # gather the dense gradient back onto the nonzero positions
+    rows, cols = bm.sparse.csr_to_coo(indices, indptr)
+    if transpose:
+      r1 = r1[cols, rows]
+    else:
+      r1 = r1[rows, cols]
+
+    self.assertTrue(bm.allclose(r1, r2))
+
+    # grad matrix
+    dense_f2 = jax.grad(lambda m: ((dense.T @ m).sum()
+                                   if transpose else
+                                   (dense @ m).sum()))
+    r3 = dense_f2(matrix)
+    r4 = jax.grad(sum_op(bm.sparse.csrmm), argnums=3)(
+      heter_data, indices, indptr, matrix,
+      shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
+      transpose=transpose)
+
+    self.assertTrue(bm.allclose(r3, r4))
+
+    bm.clear_buffer_memory()
diff --git a/brainpy/math/event.py b/brainpy/math/event.py
index 02e98b8f..3b4b5ed1 100644
--- a/brainpy/math/event.py
+++ b/brainpy/math/event.py
@@ -1,3 +1,4 @@
 from brainpy._src.math.event import (
   csrmv as csrmv,
+  csrmm as csrmm,
 )
diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py
index aa86679e..8a209901 100644
--- a/brainpy/math/sparse.py
+++ b/brainpy/math/sparse.py
@@ -4,6 +4,9 @@
 )
 from brainpy._src.math.sparse import (
   csrmv,
+  csrmm,
+
+  seg_matmul,
   csr_to_dense as csr_to_dense,
   csr_to_coo as csr_to_coo,
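Taken together, the patch exposes the new kernels as `brainpy.math.sparse.csrmm` and `brainpy.math.event.csrmm`. A minimal usage sketch of the public API (sizes and values are hypothetical; per the docstring, the event version takes a dense event matrix, here boolean spikes):

import jax.numpy as jnp
import brainpy.math as bm

# 2x3 sparse weight matrix in CSR form
data = jnp.asarray([1., 2., 3.])
indices = jnp.asarray([0, 2, 1])
indptr = jnp.asarray([0, 2, 3])

dense_b = jnp.ones((3, 4))                       # dense operand
events = jnp.asarray([[True], [False], [True]])  # event (spike) operand, shape (3, 1)

out1 = bm.sparse.csrmm(data, indices, indptr, dense_b, shape=(2, 3))
out2 = bm.event.csrmm(data, indices, indptr, events, shape=(2, 3))
print(out1.shape, out2.shape)  # (2, 4) (2, 1)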