From 26e9820ea892754fe064f8673902e12efb843c33 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Thu, 30 Nov 2023 18:16:10 +0800 Subject: [PATCH] Add sparse csr matvec using taichi customized op --- brainpy/_src/math/event/_csr_matvec_taichi.py | 582 +++++++++--------- brainpy/_src/math/sparse/__init__.py | 1 + brainpy/_src/math/sparse/_csr_mv_taichi.py | 197 ++++++ .../math/sparse/tests/test_csrmv_taichi.py | 60 ++ brainpy/math/sparse.py | 1 + 5 files changed, 551 insertions(+), 290 deletions(-) create mode 100644 brainpy/_src/math/sparse/_csr_mv_taichi.py create mode 100644 brainpy/_src/math/sparse/tests/test_csrmv_taichi.py diff --git a/brainpy/_src/math/event/_csr_matvec_taichi.py b/brainpy/_src/math/event/_csr_matvec_taichi.py index 709cfa1ee..ad168def1 100644 --- a/brainpy/_src/math/event/_csr_matvec_taichi.py +++ b/brainpy/_src/math/event/_csr_matvec_taichi.py @@ -13,7 +13,7 @@ from brainpy._src.math.interoperability import as_jax from brainpy._src.math.op_register import (XLACustomOp, - register_general_batching) + register_general_batching) from brainpy._src.math.sparse._csr_mv import csrmv as normal_csrmv from brainpy._src.math.sparse._utils import csr_to_coo from brainpy._src.dependency_check import (import_brainpylib_cpu_ops, @@ -24,319 +24,321 @@ 'csrmv_taichi' ] +_event_csr_matvec_p = None + @ti.kernel -def event_csr_matvec_cpu_transpose(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - bool_param_list: ti.types.ndarray(ndim=1), - shape_list: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - is_event_type_bool_value = bool_param_list[1] - is_heter_value = bool_param_list[2] - if is_event_type_bool_value: # type of events is boolean - if is_heter_value: # heter - ti.loop_config(serialize=True) - for row_i in range(events.shape[0]): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - else: # homo - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(events.shape[0]): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - else: # type of events is not boolean - if is_heter_value: # heter - ti.loop_config(serialize=True) - for row_i in range(events.shape[0]): - if events[row_i] > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - else: # homo - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(events.shape[0]): - if events[row_i] > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value +def _event_csr_matvec_cpu_transpose(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + bool_param_list: ti.types.ndarray(ndim=1), + shape_list: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + is_event_type_bool_value = bool_param_list[1] + is_heter_value = bool_param_list[2] + if is_event_type_bool_value: # type of events is boolean + if is_heter_value: # heter + ti.loop_config(serialize=True) + for row_i in range(events.shape[0]): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + else: # homo + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(events.shape[0]): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + else: # type of events is not boolean + if is_heter_value: # heter + ti.loop_config(serialize=True) + for row_i in range(events.shape[0]): + if events[row_i] > 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + else: # homo + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(events.shape[0]): + if events[row_i] > 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value @ti.kernel -def event_csr_matvec_cpu(values: ti.types.ndarray(ndim=1), +def _event_csr_matvec_cpu(values: ti.types.ndarray(ndim=1), indices: ti.types.ndarray(ndim=1), indptr: ti.types.ndarray(ndim=1), events: ti.types.ndarray(ndim=1), bool_param_list: ti.types.ndarray(ndim=1), shape_list: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=1)): - is_event_type_bool_value = bool_param_list[1] - is_heter_value = bool_param_list[2] - if is_event_type_bool_value: # type of events is boolean - if is_heter_value: # heter - ti.loop_config(serialize=True) - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - else: # homo - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - - else: # type of events is not boolean - if is_heter_value: # heter - ti.loop_config(serialize=True) - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] > 0.: - r += values[j] - out[row_i] = r - - else: # homo - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] > 0.: - r += value - out[row_i] = r + is_event_type_bool_value = bool_param_list[1] + is_heter_value = bool_param_list[2] + if is_event_type_bool_value: # type of events is boolean + if is_heter_value: # heter + ti.loop_config(serialize=True) + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + else: # homo + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + else: # type of events is not boolean + if is_heter_value: # heter + ti.loop_config(serialize=True) + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] > 0.: + r += values[j] + out[row_i] = r + + else: # homo + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] > 0.: + r += value + out[row_i] = r -@ti.kernel -def event_csr_matvec_gpu_transpose(values: ti.types.ndarray(), - indices: ti.types.ndarray(), - indptr: ti.types.ndarray(), - events: ti.types.ndarray(), - bool_param_list: ti.types.ndarray(), - shape_list: ti.types.ndarray(ndim=1), - out: ti.types.ndarray()): - is_event_type_bool_value = bool_param_list[1] - is_heter_value = bool_param_list[2] - if is_event_type_bool_value: # type of events is boolean - if is_heter_value: # heter - for row_i in range(events): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - else: # homo - value = values[0] - for row_i in range(events): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - else: # type of events is not boolean - if is_heter_value: # heter - for row_i in range(events): - if events[row_i] > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - else: # homo - value = values[0] - for row_i in range(events): - if events[row_i] > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value +@ti.kernel +def _event_csr_matvec_gpu_transpose(values: ti.types.ndarray(), + indices: ti.types.ndarray(), + indptr: ti.types.ndarray(), + events: ti.types.ndarray(), + bool_param_list: ti.types.ndarray(), + shape_list: ti.types.ndarray(ndim=1), + out: ti.types.ndarray()): + is_event_type_bool_value = bool_param_list[1] + is_heter_value = bool_param_list[2] + if is_event_type_bool_value: # type of events is boolean + if is_heter_value: # heter + for row_i in range(events): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + else: # homo + value = values[0] + for row_i in range(events): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + else: # type of events is not boolean + if is_heter_value: # heter + for row_i in range(events): + if events[row_i] > 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + else: # homo + value = values[0] + for row_i in range(events): + if events[row_i] > 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value @ti.kernel -def event_csr_matvec_gpu(values: ti.types.ndarray(), +def _event_csr_matvec_gpu(values: ti.types.ndarray(), indices: ti.types.ndarray(), indptr: ti.types.ndarray(), events: ti.types.ndarray(), bool_param_list: ti.types.ndarray(), shape_list: ti.types.ndarray(ndim=1), out: ti.types.ndarray()): - is_event_type_bool_value = bool_param_list[1] - is_heter_value = bool_param_list[2] - if is_event_type_bool_value: # type of events is boolean - if is_heter_value: # heter - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - else: # homo - value = values[0] - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - - else: # type of events is not boolean - if is_heter_value: # heter - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] > 0.: - r += values[j] - out[row_i] = r - - else: # homo - value = values[0] - for row_i in range(shape_list[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] > 0.: - r += value - out[row_i] = r - -event_csr_matvec_p = None + is_event_type_bool_value = bool_param_list[1] + is_heter_value = bool_param_list[2] + if is_event_type_bool_value: # type of events is boolean + if is_heter_value: # heter + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + else: # homo + value = values[0] + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + else: # type of events is not boolean + if is_heter_value: # heter + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] > 0.: + r += values[j] + out[row_i] = r + + else: # homo + value = values[0] + for row_i in range(shape_list[0]): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] > 0.: + r += value + out[row_i] = r + + + def _event_matvec_jvp( - primals, tangents, + primals, tangents, ): - values, indices, indptr, events, bool_param_list, shape_list = primals - values_dot, indices_dot, indptr_dot, events_dot, bool_param_list_dot, shape_list_dot = tangents - - - - r = event_csr_matvec_p(values, - indices, - indptr, - events, - bool_param_list, - shape_list, - outs=[jax.ShapeDtypeStruct(shape=(shape_list[1] if bool_param_list[0] else shape_list[0],), dtype=values.dtype)]) - - assert type(values_dot) is ad.Zero - assert type(indices_dot) is ad.Zero - assert type(indptr_dot) is ad.Zero - assert type(events_dot) is ad.Zero - - if type(values_dot) is ad.Zero: - if type(events_dot) is ad.Zero: - raise ValueError - # TODO: implement sparse csr matvec first - - elif type(events_dot) is ad.Zero: - dr = event_csr_matvec_p(values_dot, - indices, - indptr, - events, - bool_param_list, - shape_list, - outs=[jax.ShapeDtypeStruct(shape=(shape_list[1] if bool_param_list[0] else shape_list[0],), dtype=values.dtype)]) - - return r, dr + values, indices, indptr, events, bool_param_list, shape_list = primals + values_dot, indices_dot, indptr_dot, events_dot, bool_param_list_dot, shape_list_dot = tangents + + r = _event_csr_matvec_p(values, + indices, + indptr, + events, + bool_param_list, + shape_list, + outs=[jax.ShapeDtypeStruct(shape=(shape_list[1] if bool_param_list[0] else shape_list[0],), + dtype=values.dtype)]) + + assert type(values_dot) is ad.Zero + assert type(indices_dot) is ad.Zero + assert type(indptr_dot) is ad.Zero + assert type(events_dot) is ad.Zero + + if type(values_dot) is ad.Zero: + if type(events_dot) is ad.Zero: + raise ValueError + # TODO: implement sparse csr matvec first + + elif type(events_dot) is ad.Zero: + dr = _event_csr_matvec_p(values_dot, + indices, + indptr, + events, + bool_param_list, + shape_list, + outs=[jax.ShapeDtypeStruct( + shape=(shape_list[1] if bool_param_list[0] else shape_list[0],), + dtype=values.dtype)]) + + return r, dr def csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False + data: Union[float, jax.Array], + indices: jax.Array, + indptr: jax.Array, + events: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False ) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError('indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError('indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - bool_param_list = jnp.array([transpose, events.dtype == jnp.bool_, data.shape[0] > 1]) - shape_list = jnp.array(shape) - - global event_csr_matvec_p - if transpose: - event_csr_matvec_p = XLACustomOp(cpu_kernel=event_csr_matvec_cpu_transpose, gpu_kernel=event_csr_matvec_gpu_transpose) - else: - event_csr_matvec_p = XLACustomOp(cpu_kernel=event_csr_matvec_cpu, gpu_kernel=event_csr_matvec_gpu) - - - # computing - return event_csr_matvec_p(data, - indices, - indptr, - events, - bool_param_list, - shape_list, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)]) - + """Product of a sparse CSR matrix and a dense event vector. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + events: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + If ``transpose=True``, the operator will compute based on the + event-driven property of the ``events`` vector. + + Returns + ------- + y : Array + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + data = as_jax(data) + indices = as_jax(indices) + indptr = as_jax(indptr) + events = as_jax(events) + # checking + data = jnp.atleast_1d(data) + if np.ndim(data) == 1: + if data.shape[0] not in [1, indices.shape[0]]: + raise ValueError('The size of data should be 1 or be consistent with indices.' + f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') + else: + raise ValueError('data should be a scalar or 1D vector. ' + f'But we got {np.ndim(data)}-D array.') + if np.ndim(indices) != 1: + raise ValueError('indices should be a 1D vector with integer type.') + if np.ndim(indptr) != 1: + raise ValueError('indptr should be a 1D vector with integer type.') + if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: + raise ValueError( + 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') + if np.ndim(events) != 1: + raise ValueError('events should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if transpose: + if events.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') + else: + if events.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') + + # if the shape of indices is (0,), then we return a zero vector + if indices.shape[0] == 0: + return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) + + bool_param_list = jnp.array([transpose, events.dtype == jnp.bool_, data.shape[0] > 1]) + shape_list = jnp.array(shape) + + global _event_csr_matvec_p + if transpose: + _event_csr_matvec_p = XLACustomOp(cpu_kernel=_event_csr_matvec_cpu_transpose, + gpu_kernel=_event_csr_matvec_gpu_transpose) + else: + _event_csr_matvec_p = XLACustomOp(cpu_kernel=_event_csr_matvec_cpu, + gpu_kernel=_event_csr_matvec_gpu) + + # computing + return _event_csr_matvec_p(data, + indices, + indptr, + events, + bool_param_list, + shape_list, + outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)]) diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80b..cd94d0621 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,6 +1,7 @@ from ._coo_mv import * from ._csr_mv import * +from ._csr_mv_taichi import * from ._utils import * from ._bsr_mv import * from ._bsr_mm import * diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py new file mode 100644 index 000000000..ccdc9b8ce --- /dev/null +++ b/brainpy/_src/math/sparse/_csr_mv_taichi.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- + + +from functools import partial +from typing import Union, Tuple + +import jax +import numba +import numpy as np +import taichi as ti +from jax import core, dtypes +from jax import numpy as jnp +from jax.interpreters import ad, mlir, xla +from jax.lib import xla_client +from jaxlib import gpu_sparse + +from brainpy._src.math.interoperability import as_jax +from brainpy._src.math.ndarray import Array +from brainpy._src.math.op_register import XLACustomOp +from brainpy._src.math.sparse._utils import csr_to_coo +from brainpy._src.dependency_check import import_brainpylib_gpu_ops +from brainpy.errors import GPUOperatorNotFound + +__all__ = [ + 'csrmv_taichi', +] + +_event_csr_matvec_p = None + +@ti.kernel +def _sparse_csr_matvec_cpu_transpose(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + shape: ti.types.ndarray(ndim=1), + transpose: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + if values.shape[0] == 1: + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(shape[0]): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + else: + ti.loop_config(serialize=True) + for row_i in range(shape[0]): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += values[j] * vector[row_i] + +@ti.kernel +def _sparse_csr_matvec_cpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + shape: ti.types.ndarray(ndim=1), + transpose: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + if values.shape[0] == 1: + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += value * vector[col_indices[j]] + out[row_i] = r + + else: + ti.loop_config(serialize=True) + for row_i in range(shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + +@ti.kernel +def _sparse_csr_matvec_gpu_transpose(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + shape: ti.types.ndarray(ndim=1), + transpose: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + if values.shape[0] == 1: + value = values[0] + for row_i in range(shape[0]): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += value * vector[row_i] + + else: + for row_i in range(shape[0]): + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + out[col_indices[j]] += values[j] * vector[row_i] + +@ti.kernel +def _sparse_csr_matvec_gpu(values: ti.types.ndarray(ndim=1), + col_indices: ti.types.ndarray(ndim=1), + row_ptr: ti.types.ndarray(ndim=1), + vector: ti.types.ndarray(ndim=1), + shape: ti.types.ndarray(ndim=1), + transpose: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + if values.shape[0] == 1: + value = values[0] + for row_i in range(shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += value * vector[col_indices[j]] + out[row_i] = r + + else: + for row_i in range(shape[0]): + r = 0. + for j in range(row_ptr[row_i], row_ptr[row_i + 1]): + r += values[j] * vector[col_indices[j]] + out[row_i] = r + + +def csrmv_taichi( + data: Union[float, jnp.ndarray, Array], + indices: Union[jnp.ndarray, Array], + indptr: Union[jnp.ndarray, Array], + vector: Union[jnp.ndarray, Array], + *, + shape: Tuple[int, int], + transpose: bool = False, +): + """Product of CSR sparse matrix and a dense vector using cuSPARSE algorithm. + + This function supports JAX transformations, including `jit()`, `grad()`, + `vmap()` and `pmap()`. + + Parameters + ---------- + data: ndarray, float + An array of shape ``(nse,)``. + indices: ndarray + An array of shape ``(nse,)``. + indptr: ndarray + An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. + vector: ndarray + An array of shape ``(shape[0] if transpose else shape[1],)`` + and dtype ``data.dtype``. + shape: tuple of int + A length-2 tuple representing the matrix shape. + transpose: bool + A boolean specifying whether to transpose the sparse matrix + before computing. + + Returns + ------- + y : ndarry + The array of shape ``(shape[1] if transpose else shape[0],)`` representing + the matrix vector product. + """ + + data = jnp.atleast_1d(as_jax(data)) + indices = as_jax(indices) + indptr = as_jax(indptr) + vector = as_jax(vector) + + if vector.dtype == jnp.bool_: + vector = as_jax(vector, dtype=data.dtype) + + if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: + raise TypeError('Only support float16, float32 or float64 type. ' + f'But we got {data.dtype}.') + if data.dtype != vector.dtype: + raise TypeError('The types of data and vector should be the same. ' + f'But we got {data.dtype} != {vector.dtype}.') + assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 + if not jnp.issubdtype(indices.dtype, jnp.integer): + raise ValueError('indices should be a 1D vector with integer type.') + if not jnp.issubdtype(indptr.dtype, jnp.integer): + raise ValueError('indptr should be a 1D vector with integer type.') + out_shape = shape[1] if transpose else shape[0] + + global _event_csr_matvec_p + if transpose: + _event_csr_matvec_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_cpu_transpose, + gpu_kernel=_sparse_csr_matvec_gpu_transpose) + else: + _event_csr_matvec_p = XLACustomOp(cpu_kernel=_sparse_csr_matvec_cpu, + gpu_kernel=_sparse_csr_matvec_gpu) + + shape_list = jnp.array(shape) + is_transpose = jnp.array(transpose) + + return _event_csr_matvec_p(data, + indices, + indptr, + vector, + shape_list, + is_transpose, + outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)] + ) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py new file mode 100644 index 000000000..8aae157b0 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/test_csrmv_taichi.py @@ -0,0 +1,60 @@ +# -*- coding: utf-8 -*- + +from functools import partial + +import jax +import pytest +# from absl.testing import parameterized +import platform +import brainpy as bp +import brainpy.math as bm + +# is_manual_test = False +# if platform.system() == 'Windows' and not is_manual_test: +# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) + +vector_csr_matvec = partial(bm.sparse.csrmv, method='vector') + +homo_datas=[-1., 0., 0.1, 1.] +shapes=[(100, 200), (10, 1000), (2, 2000)] + +def test_homo(shape, homo_data): + print(f'test_homo: shape = {shape}, homo_data = {homo_data}') + conn = bp.conn.FixedProb(0.1) + + # matrix + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + # vector + rng = bm.random.RandomState(123) + vector = rng.random(shape[1]) + vector = bm.as_jax(vector) + + r1 = vector_csr_matvec(homo_data, indices, indptr, vector, shape=shape) + r2 = bm.sparse.csrmv_taichi(homo_data, indices, indptr, vector, shape=shape) + + assert(bm.allclose(r1, r2[0])) + +def test_heter(shape): + print(f'test_homo: shape = {shape}') + rng = bm.random.RandomState() + conn = bp.conn.FixedProb(0.1) + + indices, indptr = conn(*shape).require('pre2post') + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + heter_data = bm.as_jax(rng.random(indices.shape)) + vector = bm.as_jax(rng.random(shape[1])) + + r1 = vector_csr_matvec(heter_data, indices, indptr, vector, shape=shape) + r2 = bm.sparse.csrmv_taichi(heter_data, indices, indptr, vector, shape=shape) + + assert(bm.allclose(r1, r2[0])) + +# for shape in shapes: +# for homo_data in homo_datas: +# test_homo(shape, homo_data) + +for shape in shapes: + test_heter(shape) \ No newline at end of file diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9c..97c585746 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,5 +1,6 @@ from brainpy._src.math.sparse import ( csrmv, + csrmv_taichi, coomv, seg_matmul,