diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 05a7c79c..89eb9c17 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -16,7 +16,7 @@ ] _minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 0) +_minimal_taichi_version = (1, 7, 2) numba = None taichi = None diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 33677691..50bf15cb 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -8,12 +8,12 @@ from jax import numpy as jnp from jax.interpreters import ad from jax.experimental.sparse import csr +from braintaichi import event_csrmm as bt_event_csrmm from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm from brainpy._src.math.sparse.utils import csr_to_coo from brainpy._src.math.defaults import float_ @@ -49,262 +49,4 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] - - -def raw_event_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - assert len(shape) == 2 - - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' 
- f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - - assert data.ndim == indices.ndim == indptr.ndim == 1 - assert matrix.ndim == 2 - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - - # homo -> taichi - # heter -> cusparse - if data.shape[0] != 1: - if matrix.dtype == jnp.bool_: - # change dtype to float - matrix = matrix.astype(float_) - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - if transpose: - if matrix.dtype == jnp.bool_: - prim = _event_csr_matmat_transpose_homo_p - else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - else: - if matrix.dtype == jnp.bool_: - prim = _event_csr_matmat_bool_homo_p - else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - return prim(data, - indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -# taichi kernels - -@ti.kernel -def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. 
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -@ti.kernel -def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. 
- for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) - - -def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) - - -def _event_csr_matmat_transpose( - ct, data, indices, indptr, matrix, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(matrix): - ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = \ - raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.sum(ct[0] * ct_data) - else: # heter - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - ct_data = (ct[0][row] * matrix[col]).sum(1) - return ct_data, indices, indptr, matrix - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) - prim.def_transpose_rule(_event_csr_matmat_transpose) - return prim - - -# transpose heter -_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter, - gpu_kernel=_event_csr_matmat_transpose_heter) - -# no transpose heter -_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter, - gpu_kernel=_event_csr_matmat_heter) - -# transpose homo -_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo, - gpu_kernel=_event_csr_matmat_transpose_homo) - -# no transpose homo -_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo, - gpu_kernel=_event_csr_matmat_homo) - -# bool transpose heter -_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter, - gpu_kernel=_event_csr_matmat_transpose_bool_heter) - -# bool no transpose heter -_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter, - gpu_kernel=_event_csr_matmat_bool_heter) - -# bool transpose homo -_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo, - gpu_kernel=_event_csr_matmat_transpose_bool_homo) - -# bool no transpose homo -_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, - gpu_kernel=_event_csr_matmat_bool_homo) - -# heter CUSPARSE -_csr_matmat_cusparse_p = csr.csr_matmat_p -register_general_batching(_csr_matmat_cusparse_p) \ No newline at end of file + return bt_event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index d4478345..85bcedfd 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -13,16 +13,9 @@ from typing import Union, Tuple 
import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters import ad +from braintaichi import event_csrmv as bt_event_csrmv from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy.errors import PackageMissingError __all__ = [ 'csrmv' @@ -70,442 +63,5 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] - - -def raw_csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -): - if ti is None: - raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') - - if transpose: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_bool_homo_p - else: - prim = _event_csrmv_transpose_bool_heter_p - else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - else: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_bool_homo_p - else: - prim = _event_csrmv_bool_heter_p - else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - # computing - return prim(data, - indices, - indptr, - events, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -if ti is not None: - - # ------------- - # CPU 
operators - # ------------- - - # 1. The benchmarking shows that the performance of the following transpose - # kernels is maximized when using serialized mode - # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable - # arguments, we have to define each kernel separately when the - # non-differentiable/non-jittable arguments are different. - - @ti.kernel - def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - - @ti.kernel - def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - @ti.kernel - def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - - @ti.kernel - def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - @ti.kernel - def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. 
- for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - - # ------------- - # GPU operators - # ------------- - - # 1. GPU kernels are different from the CPU ones, since the GPU kernels need - # to use warp-level parallelism to achieve the best performance. - - @ti.kernel - def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - - @ti.kernel - def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - - # TODO - # It is important to note that the following warp-based kernels - # should be improved, since the atomic_add for each thread is not - # very efficient. Instead, the warp-level reduction primitive - # should be used. - # see ``warp_reduce_sum()`` function in tifunc.py. - # However, currently Taichi does not support general warp-level primitives. - - @ti.kernel - def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - - @ti.kernel - def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - - @ti.kernel - def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) - - - def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) - - - def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape - ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim - - - # transpose bool homo - _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) - - # transpose homo - _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, - _event_csr_matvec_transpose_homo_gpu) - - # not transpose bool homo - _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, - _event_csr_matvec_bool_homo_gpu) - - # not transpose homo - _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, - _event_csr_matvec_homo_gpu) - - # transpose bool heter - _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) - - # transpose heter - _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) - - # not transpose bool heter - _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, - _event_csr_matvec_bool_heter_gpu) - - # not transpose heter - _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, - _event_csr_matvec_heter_gpu) + return bt_event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index a22aac75..174deb08 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -3,25 +3,12 @@ from typing import Tuple, Optional import jax -import numpy as np -from jax import numpy as jnp +from braintaichi import jitc_event_mv_prob_homo, jitc_event_mv_prob_uniform, jitc_event_mv_prob_normal from 
brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn.matvec import (mv_prob_homo, mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) -from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy.errors import PackageMissingError + mv_prob_normal) ti = import_taichi(error_if_not_found=False) @@ -42,23 +29,10 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - weight = as_jax(weight) - if jnp.ndim(weight) < 1: - weight = jnp.expand_dims(weight, axis=0) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, + return jitc_event_mv_prob_homo(events, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel)[0] + outdim_parallel=outdim_parallel) event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ @@ -75,22 +49,8 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - if isinstance(w_low, float): w_low = as_jax(w_low) - if isinstance(w_high, float): w_high = as_jax(w_high) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ @@ -107,1054 +67,8 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - if isinstance(w_mu, float): w_mu = as_jax(w_mu) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, 
outdim_parallel=outdim_parallel) event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ - -if ti is not None: - from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - - - # ------------- - # CPU function - # ------------- - # For each non-zero event value, it generates a random key using a - # function lfsr88_key and then uses this key to compute random integers - # and update the out array based on the computed indices and weight. - # - # The function is likely designed to be parallelized. - - @ti.kernel - def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - # ------------- - # GPU function - # ------------- - # Contrary to the CPU functions, for each column, - # this function will 32 threads (one warp) to make - # the just-in-time random generation parallelized. - - @ti.kernel - def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _reverse(shape): - return shape[::-1] - - - # ------------- - # CPU function - # ------------- - # For each non-zero event value, it generates a random key using a - # function lfsr88_key and then uses this key to compute random integers - # and update the out array based on the computed indices and weight. - # - # The function is likely designed to be parallelized. - - @ti.kernel - def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - - # ------------- - # GPU function - # ------------- - # Contrary to the CPU functions, for each column, - # this function will 32 threads (one warp) to make - # the just-in-time random generation parallelized. - - @ti.kernel - def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - - def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu - ) - - - @ti.kernel - def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = 
events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu - ) - 
- - @ti.kernel - def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu - ) \ No newline at end of file diff 
--git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 296a7994..648ef78e 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -4,14 +4,15 @@ import jax import numpy as np +from braintaichi import jitc_mv_prob_homo, jitc_mv_prob_normal, jitc_mv_prob_uniform +from jax import numpy as jnp + from brainpy._src.dependency_check import import_taichi from brainpy._src.math import defaults from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array, _get_dtype +from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -from jax import numpy as jnp -from jax.interpreters import ad ti = import_taichi(error_if_not_found=False) @@ -83,22 +84,8 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_uniform( @@ -162,22 +149,8 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_normal( @@ -241,22 +214,8 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. 
""" - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return jitc_mv_prob_uniform(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def get_homo_weight_matrix( @@ -421,91 +380,6 @@ def get_normal_weight_matrix( return r -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - def raw_get_homo_weight_matrix( conn_len: jax.Array, seed: jax.Array, @@ -577,584 +451,10 @@ def raw_get_normal_weight_matrix( outdim_parallel=outdim_parallel) -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 
1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed - else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' 
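# Editor's note on the transpose rules deleted in this hunk: the connectivity matrix M is never
# materialized; it is regenerated deterministically from `seed`. The cotangent w.r.t. `vector` is
# therefore obtained by re-running the same sampler on the cotangent with `transpose=not transpose`
# and `outdim_parallel=not outdim_parallel`, which applies M^T over the identical connectivity.
# This PR assumes braintaichi's jitc_mv_prob_* ops register equivalent JVP/transpose rules, so
# gradients keep working after the delegation.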
- - -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _reverse(shape): - return shape[::-1] - if ti is not None: from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - - @ti.kernel - def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - - @ti.kernel - def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - - def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - - # outdim_parallel = False - _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - - @ti.kernel - def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. 
- key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu - ) - - # outdim_parallel = False - _mv_prob_uniform_p = 
_define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu - ) - - - @ti.kernel - def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. 
- key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu - ) - - # outdim_parallel = False - _mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu - ) - - @ti.kernel def _get_connect_matrix( clen: ti.types.ndarray(), diff --git a/brainpy/_src/math/sparse/coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py index 2885d946..439a99ae 100644 --- a/brainpy/_src/math/sparse/coo_mv.py +++ b/brainpy/_src/math/sparse/coo_mv.py @@ -1,18 +1,12 @@ # -*- coding: utf-8 -*- -import warnings -from functools import partial from typing import Union, Tuple -import numpy as np -from jax import core, numpy as jnp, dtypes, default_backend -from jax.interpreters import ad, mlir -from jaxlib import gpu_sparse +from braintaichi import coomv as bt_coomv +from jax import numpy as jnp -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import register_general_batching __all__ = [ 'coomv', @@ -66,135 +60,14 @@ def coomv( the matrix vector product. 
""" - data = jnp.atleast_1d(as_jax(data)) - row = as_jax(row) - col = as_jax(col) - vector = as_jax(vector) - - if method == 'cusparse': - if default_backend() != 'cpu': - if data.shape[0] == 1: - data = jnp.ones(row.shape, dtype=data.dtype) * data - if row.dtype in [jnp.uint32, jnp.uint64]: - row = jnp.asarray(row, dtype=dtypes.canonicalize_dtype(jnp.int64)) - if col.dtype in [jnp.uint32, jnp.uint64]: - col = jnp.asarray(col, dtype=dtypes.canonicalize_dtype(jnp.int64)) - return _coomv_cusparse_p.bind(data, - row, - col, - vector, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - else: - raise ValueError - - -# -------------------------------------------------------------------- -# cusparse_coo_matvec - - -def _coomv_impl(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - v = jnp.asarray(v) - if transpose: - row, col = col, row - out_shape = shape[1] if transpose else shape[0] - dv = data * v[col] - return jnp.zeros(out_shape, dv.dtype).at[row].add(dv) - - -def _coomv_abstract_eval(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - assert data.shape == row.shape == col.shape - assert data.dtype == v.dtype - assert row.dtype == col.dtype - assert len(shape) == 2 - assert v.ndim == 1 - assert v.shape[0] == (shape[0] if transpose else shape[1]) - out_shape = shape[1] if transpose else shape[0] - return core.ShapedArray((out_shape,), data.dtype) - - -_coo_matvec_lowering = mlir.lower_fun(_coomv_impl, multiple_results=False) - - -def _coomv_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, - shape, rows_sorted, cols_sorted, transpose): - data_aval, row_aval, _, x_aval = ctx.avals_in - dtype = data_aval.dtype - if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: - warnings.warn(f"cusparse_coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. " - "Falling back to default implementation.", UserWarning) - return _coo_matvec_lowering(ctx, data, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - if rows_sorted: - shape = shape - elif cols_sorted: - row, col = col, row - transpose = not transpose - shape = shape[::-1] - else: - warnings.warn("cusparse_coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. " - "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). 
Falling " - "back to the default implementation.", UserWarning) - return _coo_matvec_lowering(ctx, data, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - return [coo_matvec_mhlo(data, row, col, v, - shape=shape, - transpose=transpose, - index_dtype=row_aval.dtype, - data_dtype=dtype, - x_dtype=x_aval.dtype)] - - -def _coomv_jvp_mat(data_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - return _coomv_cusparse_p.bind(data_dot, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - -def _coomv_jvp_vec(v_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - return _coomv_cusparse_p.bind(data, row, col, v_dot, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - -def _coomv_transpose(ct, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - assert not ad.is_undefined_primal(row) - assert not ad.is_undefined_primal(col) - - if ad.is_undefined_primal(v): - return data, row, col, _coomv_cusparse_p.bind(data, row, col, ct, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=not transpose) - else: - return ct[row] * v[col], row, col, v - - -_coomv_cusparse_p = core.Primitive('cusparse_coo_matvec') -_coomv_cusparse_p.def_abstract_eval(_coomv_abstract_eval) -_coomv_cusparse_p.def_impl(_coomv_impl) -ad.defjvp(_coomv_cusparse_p, _coomv_jvp_mat, None, None, _coomv_jvp_vec) -ad.primitive_transposes[_coomv_cusparse_p] = _coomv_transpose -mlir.register_lowering(_coomv_cusparse_p, _coo_matvec_lowering) -mlir.register_lowering(_coomv_cusparse_p, - partial(_coomv_gpu_lowering, gpu_sparse.cuda_coo_matvec), - platform='cuda') -register_general_batching(_coomv_cusparse_p) - - + return bt_coomv( + data=data, + row=row, + col=col, + vector=vector, + shape=shape, + rows_sorted=rows_sorted, + cols_sorted=cols_sorted, + transpose=transpose, + method=method + ) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 47c24fa4..cfe36a3e 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -3,17 +3,11 @@ from typing import Union, Tuple -import jax -import numpy as np +from braintaichi import csrmm as bt_csrmm from jax import numpy as jnp -from jax.experimental.sparse import csr -from jax.interpreters import ad from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) @@ -48,180 +42,4 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product. 
""" - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] - - -def raw_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - assert len(shape) == 2 - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - data = jnp.atleast_1d(data) - - if matrix.dtype == jnp.bool_: - matrix = as_jax(matrix, dtype=data.dtype) - - if data.dtype != matrix.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {matrix.dtype}.') - assert matrix.ndim == 2 - - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - # homo -> taichi, - # heter -> cusparse - if data.shape[0] != 1: - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - if ti is None: - raise PackageMissingError.by_purpose('taichi', 'customzied sparse matrix multiplication') - if transpose: - prim = _csr_matmat_transpose_homo_p - else: - prim = _csr_matmat_homo_p - r = prim(indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], - transpose=transpose, - shape=shape) - return [r[0] * data] - - -# taichi kernels -if ti is not None: - # @ti.kernel - # def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), - # col_indices: ti.types.ndarray(ndim=1), - # row_ptr: ti.types.ndarray(ndim=1), - # matrix: ti.types.ndarray(ndim=2), - # out: ti.types.ndarray(ndim=2)): - # for row_i in range(row_ptr.shape[0] - 1): - # for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - # col = col_indices[i] - # for j in range(out.shape[1]): - # out[col, j] += values[row_i] * matrix[row_i, j] - # - # @ti.kernel - # def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), - # col_indices: ti.types.ndarray(ndim=1), - # row_ptr: ti.types.ndarray(ndim=1), - # matrix: ti.types.ndarray(ndim=2), - # out: ti.types.ndarray(ndim=2)): - # for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - # r = 0. 
- # for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - # r += values[j] * matrix[col_indices[j], col_k] - # out[row_i, col_k] = r - # - # # transpose heter - # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, - # gpu_kernel=_csr_matmat_transpose_heter) - # - # # no transpose heter - # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, - # gpu_kernel=_csr_matmat_heter) - - @ti.kernel - def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - n = out.shape[1] - m = row_ptr.shape[0] - 1 - for j in range(n): # parallize along the n dimension - for row_i in range(m): # loop along the m dimension - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - - @ti.kernel - def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - m = row_ptr.shape[0] - 1 - n = matrix.shape[1] - for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - - @ti.kernel - def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - m, n = out.shape - for row_i, col_k in ti.ndrange(m, n): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - - def _csr_matmat_jvp_matrix(mat_dot, col_indices, row_ptr, matrix, *, outs, transpose, shape): - if transpose: - return _csr_matmat_transpose_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) - else: - return _csr_matmat_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) - - - def _csr_matmat_transpose( - ct, col_indices, row_ptr, matrix, *, outs, transpose, shape, - ): - if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - assert ad.is_undefined_primal(matrix) - ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0], - shape=shape, - transpose=not transpose) - return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0]) - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(None, None, _csr_matmat_jvp_matrix) - prim.def_transpose_rule(_csr_matmat_transpose) - return prim - - - # transpose homo - _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) - - # no transpose homo - _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) - - # heter CUSPARSE - _csr_matmat_cusparse_p = csr.csr_matmat_p - register_general_batching(_csr_matmat_cusparse_p) + return bt_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py index 6eaf6b79..78aeb5a8 100644 --- a/brainpy/_src/math/sparse/csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -3,18 +3,11 @@ from 
typing import Union, Tuple -import jax +from braintaichi import csrmv as bt_csrmv from jax import numpy as jnp -from jax.experimental.sparse import csr -from jax.interpreters import ad -import brainpy.math as bm from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy.errors import PackageMissingError ti = import_taichi(error_if_not_found=False) @@ -70,256 +63,5 @@ def csrmv( the matrix vector product. """ - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) + return bt_csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose) - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: - raise TypeError('Only support float16, float32 or float64 type. ' - f'But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] - - -def raw_csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - out_shape = shape[1] if transpose else shape[0] - if data.shape[0] != 1: - if bm.get_platform() == 'gpu': - return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] - else: - if transpose: - prim = _csr_matvec_transpose_heter_p - else: - prim = _csr_matvec_heter_p - else: - if transpose: - prim = _csr_matvec_transpose_homo_p - else: - prim = _csr_matvec_homo_p - - return prim(data, - indices, - indptr, - vector, - outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -if ti is not None: - - # ------------- - # CPU operators - # ------------- - @ti.kernel - def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - - @ti.kernel - def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in 
range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - - @ti.kernel - def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - - @ti.kernel - def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - - # ------------- - # GPU operators - # ------------- - - @ti.kernel - def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - - @ti.kernel - def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - - @ti.kernel - def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - - @ti.kernel - def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. 
- j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - - def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - - def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, - ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim - - - # transpose homo - _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) - - # no transpose homo - _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) - - # transpose heter - _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) - - # no transpose heter - _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) - - # heter cusparse - _csr_matvec_cusparse_p = csr.csr_matvec_p - register_general_batching(_csr_matvec_cusparse_p) diff --git a/requirements-dev.txt b/requirements-dev.txt index 754073f4..1ad33b04 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ matplotlib msgpack tqdm pathos -taichi==1.7.0 +braintaichi numba braincore braintools diff --git a/requirements-doc.txt b/requirements-doc.txt index 7b1ebb6e..3c4a3ae6 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -5,7 +5,7 @@ matplotlib numpy scipy numba -taichi==1.7.0 +braintaichi # document requirements pandoc diff --git a/setup.py b/setup.py index 55f948e4..84ac38c1 100644 --- a/setup.py +++ b/setup.py @@ -68,9 +68,9 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'braintaichi'], + 'cuda11': ['jaxlib[cuda11_pip]', 
'brainpylib', 'numba', 'braintaichi'], + 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'braintaichi'], 'tpu': ['jaxlib[tpu]', 'numba',], 'cpu_mini': ['jaxlib>=0.4.13'], 'cuda11_mini': ['jaxlib[cuda11_pip]'],
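
Reviewer note: the public BrainPy signatures are unchanged by this migration; the wrappers now simply delegate to braintaichi. Below is a minimal smoke test of the delegated call paths, offered as a sketch only: it assumes `braintaichi` is installed and that the imported names (`jitc_mv_prob_homo`, `bt_coomv`, `bt_csrmv`, `bt_csrmm`) reproduce the behavior of the deleted taichi kernels; the shapes and seed are arbitrary.

    import jax.numpy as jnp
    import brainpy.math as bm

    # JIT-connectivity matvec y = M @ v, with M sampled on the fly (braintaichi-backed).
    v = jnp.ones(1000)
    y = bm.jitconn.mv_prob_homo(v, weight=0.5, conn_prob=0.1, seed=42,
                                shape=(800, 1000), transpose=False, outdim_parallel=True)
    assert y.shape == (800,)

    # CSR matvec on a tiny 2x3 matrix (braintaichi-backed).
    data = jnp.array([1., 2., 3.])
    indices = jnp.array([0, 2, 1])  # column indices of the non-zeros
    indptr = jnp.array([0, 2, 3])   # row 0 holds nnz [0, 2); row 1 holds nnz [2, 3)
    out = bm.sparse.csrmv(data, indices, indptr, jnp.ones(3), shape=(2, 3))
    assert out.shape == (2,)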