diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 05a7c79c..4c020231 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -6,6 +6,8 @@ __all__ = [ 'import_taichi', 'raise_taichi_not_found', + 'import_braintaichi', + 'raise_braintaichi_not_found', 'import_numba', 'raise_numba_not_found', 'import_cupy', @@ -16,10 +18,11 @@ ] _minimal_brainpylib_version = '0.2.6' -_minimal_taichi_version = (1, 7, 0) +_minimal_taichi_version = (1, 7, 2) numba = None taichi = None +braintaichi = None cupy = None cupy_jit = None brainpylib_cpu_ops = None @@ -33,6 +36,10 @@ cupy_install_info = ('We need cupy. Please install cupy by pip . \n' 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' 'For CUDA v12.x > pip install cupy-cuda12x\n') +braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n' + '> pip install braintaichi -U') + + os.environ["TI_LOG_LEVEL"] = "error" @@ -69,6 +76,26 @@ def import_taichi(error_if_not_found=True): def raise_taichi_not_found(*args, **kwargs): raise ModuleNotFoundError(taichi_install_info) +def import_braintaichi(error_if_not_found=True): + """Internal API to import braintaichi. + + If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global braintaichi + if braintaichi is None: + try: + import braintaichi as braintaichi + except ModuleNotFoundError: + if error_if_not_found: + raise_braintaichi_not_found() + else: + return None + return braintaichi + +def raise_braintaichi_not_found(): + raise ModuleNotFoundError(braintaichi_install_info) + def import_numba(error_if_not_found=True): """ diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index 8e09f95b..e517e556 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -11,7 +11,7 @@ from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, import_braintaichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.check import is_initializer @@ -20,6 +20,7 @@ from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding +bti = import_braintaichi(error_if_not_found=False) ti = import_taichi(error_if_not_found=False) __all__ = [ @@ -238,7 +239,7 @@ def update(self, x): return x -if ti is not None: +if ti is not None and bti is not None: # @numba.njit(nogil=True, fastmath=True, parallel=False) # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): @@ -273,7 +274,7 @@ def _dense_on_post( out_w[i, j] = old_w[i, j] - dense_on_post_prim = bm.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) + dense_on_post_prim = bti.XLACustomOp(cpu_kernel=_dense_on_post, gpu_kernel=_dense_on_post) # @numba.njit(nogil=True, fastmath=True, parallel=False) @@ -309,7 +310,7 @@ def _dense_on_pre( out_w[i, j] = old_w[i, j] - dense_on_pre_prim = bm.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) + dense_on_pre_prim = bti.XLACustomOp(cpu_kernel=_dense_on_pre, gpu_kernel=_dense_on_pre) else: dense_on_pre_prim = None @@ -326,6 +327,12 @@ def dense_on_pre(weight, spike, trace, w_min, w_max): w_max = np.inf w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + + weight = bm.as_jax(weight) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) return dense_on_pre_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -340,6 +347,12 @@ def dense_on_post(weight, spike, trace, w_min, w_max): w_max = np.inf w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + + weight = bm.as_jax(weight) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) return dense_on_post_prim(weight, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(weight.shape, weight.dtype)])[0] @@ -735,7 +748,7 @@ def _csr_on_pre_update( out_w[i_syn] = old_w[i_syn] - csr_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) + csr_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_csr_on_pre_update, gpu_kernel=_csr_on_pre_update) @ti.kernel @@ -759,7 +772,7 @@ def _coo_on_pre_update( out_w[i_syn] = old_w[i_syn] - coo_on_pre_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) + coo_on_pre_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_pre_update, gpu_kernel=_coo_on_pre_update) @ti.kernel @@ -783,7 +796,7 @@ def _coo_on_post_update( out_w[i_syn] = old_w[i_syn] - coo_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) + coo_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_coo_on_post_update, gpu_kernel=_coo_on_post_update) # @numba.njit(nogil=True, fastmath=True, parallel=False) @@ -824,7 +837,7 @@ def _csc_on_post_update( out_w[i_syn] = old_w[i_syn] - csc_on_post_update_prim = bm.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) + csc_on_post_update_prim = bti.XLACustomOp(cpu_kernel=_csc_on_post_update, gpu_kernel=_csc_on_post_update) else: @@ -843,6 +856,14 @@ def csr_on_pre_update(w, indices, indptr, spike, trace, w_min=None, w_max=None): w_max = np.inf w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + + w = bm.as_jax(w) + indices = bm.as_jax(indices) + indptr = bm.as_jax(indptr) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) return csr_on_pre_update_prim(w, indices, indptr, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] @@ -857,6 +878,15 @@ def coo_on_pre_update(w, pre_ids, post_ids, spike, trace, w_min=None, w_max=None w_max = np.inf w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + + w = bm.as_jax(w) + pre_ids = bm.as_jax(pre_ids) + post_ids = bm.as_jax(post_ids) + spike = bm.as_jax(spike) + trace = bm.as_jax(trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) + return coo_on_pre_update_prim(w, pre_ids, post_ids, spike, trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] @@ -871,6 +901,15 @@ def csc_on_post_update(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min= w_max = np.inf w_min = jnp.atleast_1d(w_min) w_max = jnp.atleast_1d(w_max) + + w = bm.as_jax(w) + post_ids = bm.as_jax(post_ids) + indptr = bm.as_jax(indptr) + w_ids = bm.as_jax(w_ids) + post_spike = bm.as_jax(post_spike) + pre_trace = bm.as_jax(pre_trace) + w_min = bm.as_jax(w_min) + w_max = bm.as_jax(w_max) return csc_on_post_update_prim(w, post_ids, indptr, w_ids, post_spike, pre_trace, w_min, w_max, outs=[jax.ShapeDtypeStruct(w.shape, w.dtype)])[0] diff --git a/brainpy/_src/dnn/tests/test_linear.py b/brainpy/_src/dnn/tests/test_linear.py index 9f011cb8..2fd7df2d 100644 --- a/brainpy/_src/dnn/tests/test_linear.py +++ b/brainpy/_src/dnn/tests/test_linear.py @@ -1,14 +1,11 @@ import pytest from absl.testing import absltest from absl.testing import parameterized +import jax.numpy as jnp import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) class TestLinear(parameterized.TestCase): @@ -104,11 +101,11 @@ def test_CSRLinear(self, conn): bm.random.seed() f = bp.dnn.CSRLinear(conn, weight=bp.init.Normal()) x = bm.random.random((16, 100)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (16, 100)) x = bm.random.random((100,)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (100,)) bm.clear_buffer_memory() @@ -123,10 +120,10 @@ def test_EventCSRLinear(self, conn): bm.random.seed() f = bp.layers.EventCSRLinear(conn, weight=bp.init.Normal()) x = bm.random.random((16, 100)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (16, 100)) x = bm.random.random((100,)) - y = f(x) + y = f(jnp.asarray(x)) self.assertTrue(y.shape == (100,)) bm.clear_buffer_memory() diff --git a/brainpy/_src/dnn/tests/test_mode.py b/brainpy/_src/dnn/tests/test_mode.py index 10e9eeda..eb87c201 100644 --- a/brainpy/_src/dnn/tests/test_mode.py +++ b/brainpy/_src/dnn/tests/test_mode.py @@ -4,10 +4,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) class Test_Conv(parameterized.TestCase): diff --git a/brainpy/_src/dyn/projections/tests/test_STDP.py b/brainpy/_src/dyn/projections/tests/test_STDP.py index 18d9d9dc..8fab94de 100644 --- a/brainpy/_src/dyn/projections/tests/test_STDP.py +++ b/brainpy/_src/dyn/projections/tests/test_STDP.py @@ -6,10 +6,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/dyn/projections/tests/test_aligns.py b/brainpy/_src/dyn/projections/tests/test_aligns.py index eec2c945..8bf2c150 100644 --- a/brainpy/_src/dyn/projections/tests/test_aligns.py +++ b/brainpy/_src/dyn/projections/tests/test_aligns.py @@ -5,10 +5,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) neu_pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., V_initializer=bp.init.Normal(-55., 2.)) diff --git a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py index d068f207..0b371bcb 100644 --- a/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_abstract_synapses.py @@ -7,10 +7,6 @@ import brainpy as bp import brainpy.math as bm from brainpy._src.dynold.synapses import abstract_models -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) class Test_Abstract_Synapse(parameterized.TestCase): diff --git a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py index 01a31526..b48cb5b7 100644 --- a/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py +++ b/brainpy/_src/dynold/synapses/tests/test_biological_synapses.py @@ -6,10 +6,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) biological_models = [ bp.synapses.AMPA, diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index de559de5..01159883 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -44,7 +44,7 @@ from .compat_numpy import * from .compat_tensorflow import * from .others import * -from . import random, linalg, fft, tifunc +from . import random, linalg, fft # operators from .op_register import * diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index 33677691..b78afad7 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -3,21 +3,13 @@ from typing import Union, Tuple -import jax -import numpy as np + from jax import numpy as jnp -from jax.interpreters import ad -from jax.experimental.sparse import csr -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy._src.math.sparse.csr_mm import raw_csrmm_taichi as normal_csrmm -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy._src.math.defaults import float_ +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found -ti = import_taichi() +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmm', @@ -49,262 +41,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - return raw_event_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] - - -def raw_event_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - assert len(shape) == 2 - - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - - assert data.ndim == indices.ndim == indptr.ndim == 1 - assert matrix.ndim == 2 - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - # if the shape of indices is (0,), then we return a zero matrix - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - - # homo -> taichi - # heter -> cusparse - if data.shape[0] != 1: - if matrix.dtype == jnp.bool_: - # change dtype to float - matrix = matrix.astype(float_) - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - if transpose: - if matrix.dtype == jnp.bool_: - prim = _event_csr_matmat_transpose_homo_p - else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - else: - if matrix.dtype == jnp.bool_: - prim = _event_csr_matmat_bool_homo_p - else: - return normal_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) - return prim(data, - indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -# taichi kernels - -@ti.kernel -def _event_csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_transpose_bool_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += values[j] * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_bool_heter(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += values[row_j] * matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - -@ti.kernel -def _event_csr_matmat_transpose_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i] != 0.: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_transpose_bool_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for col_i, row_k in ti.ndrange(out.shape[1], out.shape[0]): - for row_j in range(matrix.shape[0]): - if matrix[row_j, col_i]: - for j in range(row_ptr[row_j], row_ptr[row_j + 1]): - if col_indices[j] == row_k: - out[row_k, col_i] += value * matrix[row_j, col_i] - - -@ti.kernel -def _event_csr_matmat_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k] != 0.: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -@ti.kernel -def _event_csr_matmat_bool_homo(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - value = values[0] - for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - if matrix[col_indices[row_j], col_k]: - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r * value - - -def _event_csr_matmat_jvp_values(val_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(val_dot, col_indices, row_ptr, matrix, shape=shape, transpose=transpose) - - -def _event_csr_matmat_jvp_matrix(mat_dot, values, col_indices, row_ptr, matrix, *, outs, transpose, shape): - return normal_csrmm(values, col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose) - - -def _event_csr_matmat_transpose( - ct, data, indices, indptr, matrix, *, outs, transpose, shape, -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(matrix): - ct_matrix = raw_event_csrmm_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = \ - raw_event_csrmm_taichi(jnp.ones(1), indices, indptr, matrix, shape=shape, transpose=transpose)[0] - ct_data = jnp.sum(ct[0] * ct_data) - else: # heter - matrix = jnp.asarray(matrix) - row, col = csr_to_coo(indices, indptr) - ct_data = (ct[0][row] * matrix[col]).sum(1) - return ct_data, indices, indptr, matrix - - -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matmat_jvp_values, None, None, _event_csr_matmat_jvp_matrix) - prim.def_transpose_rule(_event_csr_matmat_transpose) - return prim - - -# transpose heter -_event_csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_heter, - gpu_kernel=_event_csr_matmat_transpose_heter) - -# no transpose heter -_event_csr_matmat_heter_p = _define_op(cpu_kernel=_event_csr_matmat_heter, - gpu_kernel=_event_csr_matmat_heter) - -# transpose homo -_event_csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_homo, - gpu_kernel=_event_csr_matmat_transpose_homo) - -# no transpose homo -_event_csr_matmat_homo_p = _define_op(cpu_kernel=_event_csr_matmat_homo, - gpu_kernel=_event_csr_matmat_homo) - -# bool transpose heter -_event_csr_matmat_transpose_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_heter, - gpu_kernel=_event_csr_matmat_transpose_bool_heter) - -# bool no transpose heter -_event_csr_matmat_bool_heter_p = _define_op(cpu_kernel=_event_csr_matmat_bool_heter, - gpu_kernel=_event_csr_matmat_bool_heter) - -# bool transpose homo -_event_csr_matmat_transpose_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_transpose_bool_homo, - gpu_kernel=_event_csr_matmat_transpose_bool_homo) - -# bool no transpose homo -_event_csr_matmat_bool_homo_p = _define_op(cpu_kernel=_event_csr_matmat_bool_homo, - gpu_kernel=_event_csr_matmat_bool_homo) + if bti is None: + raise_braintaichi_not_found() -# heter CUSPARSE -_csr_matmat_cusparse_p = csr.csr_matmat_p -register_general_batching(_csr_matmat_cusparse_p) \ No newline at end of file + return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index d4478345..3969ee6b 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -13,22 +13,16 @@ from typing import Union, Tuple import jax -import jax.numpy as jnp -import numpy as np -from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import XLACustomOp -from brainpy._src.math.sparse.csr_mv import raw_csrmv_taichi as normal_csrmv_taichi -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy.errors import PackageMissingError +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) + __all__ = [ 'csrmv' ] -ti = import_taichi(error_if_not_found=False) def csrmv( @@ -70,442 +64,7 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indices should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if indptr.dtype not in [jnp.int8, jnp.int16, jnp.int32, jnp.int64, jnp.uint8, jnp.uint16, jnp.uint32, jnp.uint64]: - raise ValueError( - 'indptr should be a 1D vector with int8, int16, int32, int64, uint8, uint16, uint32 or uint64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] - - -def raw_csrmv_taichi( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -): - if ti is None: - raise PackageMissingError.by_purpose(name='taichi==1.7.0', purpose='customized operators') - - if transpose: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_transpose_bool_homo_p - else: - prim = _event_csrmv_transpose_bool_heter_p - else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - else: - if events.dtype == jnp.bool_: - if data.shape[0] == 1: - prim = _event_csrmv_bool_homo_p - else: - prim = _event_csrmv_bool_heter_p - else: - return normal_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) - - # computing - return prim(data, - indices, - indptr, - events, - outs=[jax.ShapeDtypeStruct(shape=(shape[1] if transpose else shape[0],), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -if ti is not None: - - # ------------- - # CPU operators - # ------------- - - # 1. The benchmarking shows that the performance of the following transpose - # kernels is maximized when using serialized mode - # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable - # arguments, we have to define each kernel separately when the - # non-differentiable/non-jittable arguments are different. - - @ti.kernel - def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - - @ti.kernel - def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - @ti.kernel - def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - - @ti.kernel - def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - - @ti.kernel - def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - - @ti.kernel - def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - - # ------------- - # GPU operators - # ------------- - - # 1. GPU kernels are different from the CPU ones, since the GPU kernels need - # to use warp-level parallelism to achieve the best performance. - - @ti.kernel - def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - - @ti.kernel - def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - - # TODO - # It is important to note that the following warp-based kernels - # should be improved, since the atomic_add for each thread is not - # very efficient. Instead, the warp-level reduction primitive - # should be used. - # see ``warp_reduce_sum()`` function in tifunc.py. - # However, currently Taichi does not support general warp-level primitives. - - @ti.kernel - def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - - @ti.kernel - def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - - @ti.kernel - def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - @ti.kernel - def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) - - - def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) - - - def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape - ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim - - - # transpose bool homo - _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) - - # transpose homo - _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, - _event_csr_matvec_transpose_homo_gpu) - - # not transpose bool homo - _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, - _event_csr_matvec_bool_homo_gpu) - - # not transpose homo - _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, - _event_csr_matvec_homo_gpu) - - # transpose bool heter - _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) - - # transpose heter - _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) - - # not transpose bool heter - _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, - _event_csr_matvec_bool_heter_gpu) + if bti is None: + raise_braintaichi_not_found() - # not transpose heter - _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, - _event_csr_matvec_heter_gpu) + return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index e82bcdb7..ea830347 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -8,10 +8,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) import platform force_test = False # turn on to force test on windows locally diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index a22aac75..80bba29b 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -3,27 +3,13 @@ from typing import Tuple, Optional import jax -import numpy as np -from jax import numpy as jnp -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.jitconn.matvec import (mv_prob_homo, mv_prob_uniform, - mv_prob_normal, - _general_checking, - raw_mv_prob_homo, - raw_mv_prob_uniform, - raw_mv_prob_normal, - _mv_prob_homo_transpose, - _mv_prob_uniform_transpose, - _mv_prob_normal_transpose, - _reverse) -from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import XLACustomOp -from brainpy.errors import PackageMissingError + mv_prob_normal) +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found -ti = import_taichi(error_if_not_found=False) +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'event_mv_prob_homo', @@ -42,23 +28,12 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - weight = as_jax(weight) - if jnp.ndim(weight) < 1: - weight = jnp.expand_dims(weight, axis=0) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel)[0] + outdim_parallel=outdim_parallel) event_mv_prob_homo.__doc__ = mv_prob_homo.__doc__ @@ -75,22 +50,10 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - if isinstance(w_low, float): w_low = as_jax(w_low) - if isinstance(w_high, float): w_high = as_jax(w_high) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_uniform(events, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_uniform.__doc__ = mv_prob_uniform.__doc__ @@ -107,1054 +70,10 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - events = as_jax(events) - if isinstance(w_mu, float): w_mu = as_jax(w_mu) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_normal(events, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) event_mv_prob_normal.__doc__ = mv_prob_normal.__doc__ - -if ti is not None: - from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - - - # ------------- - # CPU function - # ------------- - # For each non-zero event value, it generates a random key using a - # function lfsr88_key and then uses this key to compute random integers - # and update the out array based on the computed indices and weight. - # - # The function is likely designed to be parallelized. - - @ti.kernel - def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - # ------------- - # GPU function - # ------------- - # Contrary to the CPU functions, for each column, - # this function will 32 threads (one warp) to make - # the just-in-time random generation parallelized. - - @ti.kernel - def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _reverse(shape): - return shape[::-1] - - - # ------------- - # CPU function - # ------------- - # For each non-zero event value, it generates a random key using a - # function lfsr88_key and then uses this key to compute random integers - # and update the out array based on the computed indices and weight. - # - # The function is likely designed to be parallelized. - - @ti.kernel - def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - - # ------------- - # GPU function - # ------------- - # Contrary to the CPU functions, for each column, - # this function will 32 threads (one warp) to make - # the just-in-time random generation parallelized. - - @ti.kernel - def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - - def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu - ) - - - @ti.kernel - def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu - ) - - - @ti.kernel - def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - @ti.kernel - def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel - ): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - - def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, - ) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - - # outdim_parallel = True, events.dtype = jnp.bool_ - _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu - ) - - # outdim_parallel = False, events.dtype = jnp.bool_ - _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu - ) - - # outdim_parallel = True, events.dtype != jnp.bool_ - _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu - ) - - # outdim_parallel = False, events.dtype != jnp.bool_ - _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu - ) \ No newline at end of file diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 296a7994..4d4dd25a 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -4,16 +4,16 @@ import jax import numpy as np -from brainpy._src.dependency_check import import_taichi +from jax import numpy as jnp + from brainpy._src.math import defaults from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array, _get_dtype +from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -from jax import numpy as jnp -from jax.interpreters import ad +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found -ti = import_taichi(error_if_not_found=False) +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'mv_prob_homo', @@ -83,22 +83,11 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + if bti is None: + raise_braintaichi_not_found() - vector = as_jax(vector) - if isinstance(weight, float): - weight = as_jax(weight, dtype=vector.dtype) - weight = jnp.atleast_1d(as_jax(weight)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - clen = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.asarray(seed, dtype=jnp.uint32) - seed = jnp.atleast_1d(seed) - return raw_mv_prob_homo(vector, weight, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_uniform( @@ -162,22 +151,11 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') + if bti is None: + raise_braintaichi_not_found() - vector = as_jax(vector) - if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) - if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_uniform(vector, w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def mv_prob_normal( @@ -241,22 +219,10 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - vector = as_jax(vector) - if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) - if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_mv_prob_normal(vector, w_mu, w_sigma, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def get_homo_weight_matrix( @@ -290,26 +256,9 @@ def get_homo_weight_matrix( out: Array, ndarray The connection matrix :math:`M`. """ - if isinstance(weight, numbers.Number): - weight = jnp.atleast_1d(jnp.asarray(weight, dtype=defaults.float_)) - else: - raise ValueError(f'weight must be a number type, but get {type(weight)}') - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - r = raw_get_homo_weight_matrix(conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0].astype(jnp.bool_) - r *= weight - if transpose: - return r.transpose() - else: - return r + if bti is None: + raise_braintaichi_not_found() + return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) def get_uniform_weight_matrix( @@ -348,23 +297,10 @@ def get_uniform_weight_matrix( out: Array, ndarray The weight matrix :math:`M`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - r = raw_get_uniform_weight_matrix(w_low, w_high, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - if transpose: - return r.transpose() - else: - return r + if bti is None: + raise_braintaichi_not_found() + return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) def get_normal_weight_matrix( @@ -401,912 +337,9 @@ def get_normal_weight_matrix( out: Array, ndarray The weight matrix :math:`M`. """ - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_len = jnp.ceil(1 / conn_prob) * 2 - 1 - conn_len = jnp.asarray(jnp.atleast_1d(conn_len), dtype=jnp.int32) - if seed is None: - with jax.ensure_compile_time_eval(): - seed = np.random.randint(0, int(1e8), 1) - seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - r = raw_get_normal_weight_matrix(w_mu, w_sigma, conn_len, seed, - shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] - if transpose: - return r.transpose() - else: - return r - - -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_get_homo_weight_matrix( - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - if outdim_parallel: - prim = _get_connect_matrix_outdim_parallel_p - else: - prim = _get_connect_matrix_p - - return prim(conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.int32)], - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_get_uniform_weight_matrix( - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - if outdim_parallel: - prim = _get_uniform_weight_matrix_outdim_parallel_p - else: - prim = _get_uniform_weight_matrix_p - - return prim(w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def raw_get_normal_weight_matrix( - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - if outdim_parallel: - prim = _get_normal_weight_matrix_outdim_parallel_p - else: - prim = _get_normal_weight_matrix_p - - return prim(w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=shape, dtype=jnp.float32)], - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed - else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed - else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed - else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed - else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _reverse(shape): - return shape[::-1] - - -if ti is not None: - from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) - - - @ti.kernel - def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - - @ti.kernel - def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction - - - def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - - def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - - # outdim_parallel = False - _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - - @ti.kernel - def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu - ) - - # outdim_parallel = False - _mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu - ) - - - @ti.kernel - def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - - @ti.kernel - def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) - ): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - - def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - - def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - - # outdim_parallel = True - _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu - ) - - # outdim_parallel = False - _mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu - ) - - - @ti.kernel - def _get_connect_matrix( - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row, i_col] = 1 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _get_connect_matrix_outdim_parallel( - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - out[i_row, i_col] = 1 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - - - _get_connect_matrix_p = XLACustomOp(cpu_kernel=_get_connect_matrix, gpu_kernel=_get_connect_matrix) - _get_connect_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_connect_matrix_outdim_parallel, - gpu_kernel=_get_connect_matrix_outdim_parallel) - - - @ti.kernel - def _get_uniform_weight_matrix( - w_low: ti.types.ndarray(), - w_high: ti.types.ndarray(), - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - w_low0 = w_low[0] - w_high0 = w_high[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_low0, w_high0) - out[i_row, i_col] = raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _get_uniform_weight_matrix_outdim_parallel( - w_low: ti.types.ndarray(), - w_high: ti.types.ndarray(), - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - w_low0 = w_low[0] - w_high0 = w_high[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_low0, w_high0) - out[i_row, i_col] = raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - - - _get_uniform_weight_matrix_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix, - gpu_kernel=_get_uniform_weight_matrix) - _get_uniform_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_uniform_weight_matrix_outdim_parallel, - gpu_kernel=_get_uniform_weight_matrix_outdim_parallel) - - - @ti.kernel - def _get_normal_weight_matrix( - w_mu: ti.types.ndarray(), - w_sigma: ti.types.ndarray(), - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row, i_col] = raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - - @ti.kernel - def _get_normal_weight_matrix_outdim_parallel( - w_mu: ti.types.ndarray(), - w_sigma: ti.types.ndarray(), - clen: ti.types.ndarray(), - seed: ti.types.ndarray(), - out: ti.types.ndarray(), - ): - num_row = out.shape[0] - num_col = out.shape[1] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row, i_col] = raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - + if bti is None: + raise_braintaichi_not_found() + return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) - _get_normal_weight_matrix_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix, - gpu_kernel=_get_normal_weight_matrix) - _get_normal_weight_matrix_outdim_parallel_p = XLACustomOp(cpu_kernel=_get_normal_weight_matrix_outdim_parallel, - gpu_kernel=_get_normal_weight_matrix_outdim_parallel) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index 01e9acf8..e2c91493 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -7,10 +7,6 @@ from absl.testing import parameterized import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) import platform force_test = False # turn on to force test on windows locally diff --git a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py index 2b04ab61..0cf00071 100644 --- a/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py +++ b/brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py @@ -6,11 +6,6 @@ from absl.testing import parameterized import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) - import platform force_test = False # turn on to force test on windows locally diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 7720011a..d69def9a 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -7,10 +7,6 @@ from absl.testing import parameterized import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) import platform force_test = False # turn on to force test on windows locally diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py index 7e59e8c0..19160708 100644 --- a/brainpy/_src/math/op_register/__init__.py +++ b/brainpy/_src/math/op_register/__init__.py @@ -3,6 +3,5 @@ compile_cpu_signature_with_numba) from .base import XLACustomOp from .utils import register_general_batching -from .taichi_aot_based import clear_taichi_aot_caches, count_taichi_aot_kernels from .base import XLACustomOp from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 20a48778..a6dd5a5b 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -9,20 +9,16 @@ from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject +is_version_right = False if jax.__version__ >= '0.4.16': from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule - from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, + from braintaichi._primitive._mlir_translation_rule import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) from .cupy_based import ( register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) -else: - from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule - from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, - register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) - from .cupy_based import ( - register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, - register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) + is_version_right = True + from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp @@ -73,6 +69,8 @@ def __init__( outs: Optional[Callable] = None, name: str = None, ): + if not is_version_right: + raise ImportError('XLA Custom Op is only supported in JAX>=0.4.16') super().__init__(name) # set cpu_kernel and gpu_kernel diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py deleted file mode 100644 index 858f338b..00000000 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ /dev/null @@ -1,526 +0,0 @@ -import contextlib -import hashlib -import inspect -import io -import os -import pathlib -import platform -import re -import shutil -from functools import partial, reduce -from typing import Any, Sequence, Union - -import jax.core -import numpy as np -from jax.interpreters import xla, mlir -from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call - -from brainpy._src.dependency_check import (import_taichi, - import_brainpylib_cpu_ops, - import_brainpylib_gpu_ops) -from brainpy.errors import PackageMissingError -from .utils import _shape_to_layout - - -taichi_cache_path = None - - -# --- UTILS ### - -# get the path of home directory on Linux, Windows, Mac -def get_home_dir(): - return str(pathlib.Path.home()) - - -# encode a string with md5 -def encode_md5(source: str) -> str: - # create md5 object - md5 = hashlib.md5() - - # encode source - source_encode = source.encode(encoding='utf-8') - - # update md5 object - md5.update(source_encode) - - return md5.hexdigest() - - -# check kernels count -def count_taichi_aot_kernels() -> int: - """ - Count the number of AOT compiled kernels. - - Returns - ------- - kernels_count: int - The number of AOT compiled kernels. - - """ - if not os.path.exists(kernels_aot_path): - return 0 - kernels_count = 0 - dir1 = os.listdir(kernels_aot_path) - for i in dir1: - dir2 = os.listdir(os.path.join(kernels_aot_path, i)) - kernels_count += len(dir2) - return kernels_count - - -def clear_taichi_aot_caches(kernels: Union[str, Sequence[str]] = None): - """ - Clean the cache of the AOT compiled kernels. - - Parameters - ---------- - kernels: str or list of str - The name of the kernel to be cleaned. If None, all the kernels will be cleaned. - """ - if kernels is None: - global taichi_cache_path - if taichi_cache_path is None: - from taichi._lib.utils import import_ti_python_core - taichi_cache_path = import_ti_python_core().get_repo_dir() - # clean taichi cache - if os.path.exists(taichi_cache_path): - shutil.rmtree(taichi_cache_path) - # clean brainpy-taichi AOT cache - if os.path.exists(kernels_aot_path): - shutil.rmtree(kernels_aot_path) - return - if isinstance(kernels, str): - kernels = [kernels] - if not isinstance(kernels, list): - raise TypeError(f'kernels_name must be a list of str, but got {type(kernels)}') - # clear brainpy kernel cache - for kernel_name in kernels: - if os.path.exists(os.path.join(kernels_aot_path, kernel_name)): - shutil.rmtree(os.path.join(kernels_aot_path, kernel_name)) - - -# TODO -# not a very good way -# get source with dependencies -def get_source_with_dependencies(func, visited=None): - if visited is None: - visited = set() - - source = inspect.getsource(func) - if func in visited: - return '' - - visited.add(func) - module = inspect.getmodule(func) - dependent_funcs = re.findall(r'(\w+)\(', source) - - for func_name in dependent_funcs: - dependent_func = getattr(module, func_name, None) - if callable(dependent_func): - source += get_source_with_dependencies(dependent_func, visited) - return source - - -# check if Metal is supported -def is_metal_supported(): - # first check if we are on macOS - if platform.system() != 'Darwin': - return False - if platform.processor() != 'arm': - return False - return True - - -# --- VARIABLES ### -home_path = get_home_dir() -kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels') -is_metal_device = is_metal_supported() - - -# check if a kernel exists in the database -def _check_kernel_exist(source_md5_encode: str) -> bool: - # get the realpath of the kernel - kernel_path = os.path.join(kernels_aot_path, source_md5_encode) - - # check whether the kernel exists - if os.path.exists(kernel_path): - return True - else: - return False - - -# --- KERNEL AOT BUILD ### - - -def _array_to_field(dtype, shape) -> Any: - ti = import_taichi() - if dtype == np.bool_: - dtype = bool - elif dtype == np.int8: - dtype = ti.int8 - elif dtype == np.int16: - dtype = ti.int16 - elif dtype == np.int32: - dtype = ti.int32 - elif dtype == np.int64: - dtype = ti.int64 - elif dtype == np.uint8: - dtype = ti.uint8 - elif dtype == np.uint16: - dtype = ti.uint16 - elif dtype == np.uint32: - dtype = ti.uint32 - elif dtype == np.uint64: - dtype = ti.uint64 - elif dtype == np.float16: - dtype = ti.float16 - elif dtype == np.float32: - dtype = ti.float32 - elif dtype == np.float64: - dtype = ti.float64 - else: - raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. ' - f'If you think it is necessary, please open an issue at ' - f'https://github.com/brainpy/BrainPy/issues/new') - return ti.field(dtype=dtype, shape=shape) - - -# build aot kernel -def _build_kernel( - source_md5_encode: str, - kernel: callable, - ins: dict, - outs: dict, - device: str -): - ti = import_taichi() - - # init arch - if device == 'cpu': - if is_metal_device: - arch = ti.arm64 - device = 'arm64' - else: - arch = ti.x64 - elif device == 'gpu': - arch = ti.cuda - else: - raise ValueError(f'Unknown device: {device}') - with contextlib.redirect_stdout(io.StringIO()): - ti.init(arch=arch) - - # check arch is available - if ti.lang.impl.current_cfg().arch != arch: - raise RuntimeError(f"Arch {arch} is not available") - - # get kernel name - kernel_name = kernel.__name__ - - # replace the name of the func - kernel.__name__ = f'taichi_kernel_{device}' - - # init template_args_dict - template_args_dict = {} - for key, value in ins.items(): - template_args_dict[key] = _array_to_field(value[0], value[1]) - for key, value in outs.items(): - template_args_dict[key] = _array_to_field(value[0], value[1]) - - # make aot dir - kernel_path = os.path.join(kernels_aot_path, source_md5_encode) - os.makedirs(kernel_path, exist_ok=True) - - # compile kernel - mod = ti.aot.Module(arch) - mod.add_kernel(kernel, template_args=template_args_dict) - mod.save(kernel_path) - - # rename kernel name - kernel.__name__ = kernel_name - - -# --- KERNEL CALL PREPROCESS ### - -# convert type to number -type_number_map = { - int: 0, - float: 1, - bool: 2, - np.dtype('int32'): 0, - np.dtype('float32'): 1, - np.dtype('bool'): 2, - np.dtype('uint8'): 3, - np.dtype('uint16'): 4, - np.dtype('uint32'): 5, - np.dtype('uint64'): 6, - np.dtype('int8'): 7, - np.dtype('int16'): 8, - np.dtype('int64'): 9, - np.dtype('float16'): 10, - np.dtype('float64'): 11, -} - - -# preprocess kernel call cpu -def _preprocess_kernel_call_cpu( - source_md5_encode: str, - ins: Sequence, - outs: Sequence, -) -> list: - in_out_info = [] - max_dim_count = 0 - for value in ins: - if value.ndim > max_dim_count: - max_dim_count = value.ndim - - for value in outs: - if value.ndim > max_dim_count: - max_dim_count = value.ndim - - # kernel_path - kernel_path = os.path.join(kernels_aot_path, source_md5_encode) - kernel_path = bytes(kernel_path, encoding='utf-8') + b'\0' - kernel_path = np.array(list(kernel_path), dtype=np.uint8) - - # other args - in_out_num = np.array([len(ins), len(outs), kernel_path.size], dtype=np.uint32) - in_out_type_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32) - in_out_dim_count_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32) - in_out_elem_count_list = np.zeros((len(ins) + len(outs),), dtype=np.uint32) - in_out_shape_list = np.zeros((len(ins) + len(outs), max_dim_count), dtype=np.uint32) - - for i, value in enumerate(ins): - in_out_type_list[i] = type_number_map[value.dtype] - in_out_dim_count_list[i] = value.ndim - in_out_elem_count_list[i] = value.size - for j, dim in enumerate(value.shape): - in_out_shape_list[i, j] = dim - - b = len(ins) - for i, value in enumerate(outs): - in_out_type_list[i + b] = type_number_map[value.dtype] - in_out_dim_count_list[i + b] = value.ndim - in_out_elem_count_list[i + b] = value.size - for j, dim in enumerate(value.shape): - in_out_shape_list[i + b, j] = dim - - in_out_info.append(in_out_num) - in_out_info.append(in_out_type_list) - in_out_info.append(in_out_dim_count_list) - in_out_info.append(in_out_elem_count_list) - in_out_info.append(in_out_shape_list) - in_out_info.append(kernel_path) - - return in_out_info - - -def _preprocess_kernel_call_gpu( - source_md5_encode: str, - ins: Sequence, - outs: Sequence, -) -> bytes: - # if len(ins) + len(outs) > 8: - # raise ValueError('The number of ins and outs must be less than 8!') - kernel_path = os.path.join(kernels_aot_path, source_md5_encode) - - # other args - param_total_num = len(ins) + len(outs) - in_out_num = [len(ins), len(outs)] - in_out_type_list = [0] * param_total_num - in_out_dim_count_list = [0] * param_total_num - in_out_elem_count_list = [0] * param_total_num - in_out_shape_list = [0] * param_total_num * 8 - - for i, value in enumerate(ins): - in_out_type_list[i] = type_number_map[value.dtype] - in_out_dim_count_list[i] = value.ndim - in_out_elem_count_list[i] = value.size - for j, dim in enumerate(value.shape): - in_out_shape_list[i * 8 + j] = dim - - for i, value in enumerate(outs): - in_out_type_list[i + len(ins)] = type_number_map[value.dtype] - in_out_dim_count_list[i + len(ins)] = value.ndim - in_out_elem_count_list[i + len(ins)] = value.size - for j, dim in enumerate(value.shape): - in_out_shape_list[(i + len(ins)) * 8 + j] = dim - - # covert to string - in_out_num_str = ",".join(str(i) for i in in_out_num) - in_out_type_list_str = ",".join(str(i) for i in in_out_type_list) - in_out_dim_count_list_str = ",".join(str(i) for i in in_out_dim_count_list) - in_out_elem_count_list_str = ",".join(str(i) for i in in_out_elem_count_list) - in_out_shape_list_str = ",".join(str(i) for i in in_out_shape_list) - - opaque = (bytes(in_out_num_str, encoding='utf-8') + b';' + - bytes(in_out_type_list_str, encoding='utf-8') + b';' + - bytes(in_out_dim_count_list_str, encoding='utf-8') + b';' + - bytes(in_out_elem_count_list_str, encoding='utf-8') + b';' + - bytes(in_out_shape_list_str, encoding='utf-8') + b';' + - bytes(kernel_path, encoding='utf-8')) - - return opaque - - -def _XlaOp_to_ShapedArray(c, xla_op): - xla_op = c.get_shape(xla_op) - return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()) - - -def _mlir_to_ShapedArray(c, op): - return op - - -def _kernel_to_code(kernel, abs_ins, abs_outs, platform): - codes = f'[taichi {platform} kernel]\n' + get_source_with_dependencies(kernel) - codes += '\n[ins]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_ins])) - codes += '\n[outs]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_outs])) - return codes - - -def _compile_kernel(abs_ins, kernel, platform: str, **kwargs): - # input and output abstract information - abs_outs = kwargs['outs'] - - # kernel to code - codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform) - source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes)) - - # create ins, outs dict from kernel's args - in_num = len(abs_ins) - names = tuple(inspect.signature(kernel).parameters.keys()) - in_names, out_names = names[:in_num], names[in_num:] - ins_dict = {key: (abs_ins[i].dtype, abs_ins[i].shape) for i, key in enumerate(in_names)} - outs_dict = {key: (abs_outs[i].dtype, abs_outs[i].shape) for i, key in enumerate(out_names)} - - # build kernels - if not _check_kernel_exist(source_md5_encode): # TODO: more checking - try: - _build_kernel(source_md5_encode, kernel, ins_dict, outs_dict, platform) - except Exception as e: - try: - os.removedirs(os.path.join(kernels_aot_path, source_md5_encode)) - except Exception: - raise RuntimeError(f'Failed to preprocess info to build kernel:\n\n {codes}') from e - raise RuntimeError(f'Failed to build kernel:\n\n {codes}') from e - - # returns - if platform in ['gpu', 'cuda']: - import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs) - return opaque - elif platform == 'cpu': - import_brainpylib_cpu_ops() - in_out_info = _preprocess_kernel_call_cpu(source_md5_encode, abs_ins, abs_outs) - return in_out_info - else: - raise ValueError(f'Unknown platform: {platform}') - - -def _get_abs_ins(c, ins): - abs_ins = [] - for v in ins: - xla_op = c.get_shape(v) - abs_ins.append(jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())) - return abs_ins - - -def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs): - in_out_info = _compile_kernel(_get_abs_ins(c, ins), kernel, 'cpu', **kwargs) - ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) - if is_metal_device: - fn = b'taichi_kernel_aot_call_cpu_arm64' - else: - fn = b'taichi_kernel_aot_call_cpu' - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=ins, - operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), - shape_with_layout=xla_client.Shape.tuple_shape( - [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) - for value in kwargs['outs']] - ), - ) - - -def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): - opaque = _compile_kernel(_get_abs_ins(c, ins), kernel, 'gpu', **kwargs) - return xla_client.ops.CustomCallWithLayout( - c, - b'taichi_kernel_aot_call_gpu', - operands=ins, - operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), - shape_with_layout=xla_client.Shape.tuple_shape( - [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) - for value in kwargs['outs']] - ), - opaque=opaque, - ) - - -def register_taichi_aot_xla_cpu_translation_rule(primitive, cpu_kernel): - xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_xla_cpu_translation_rule, cpu_kernel) - - -def register_taichi_aot_xla_gpu_translation_rule(primitive, gpu_kernel): - xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_xla_gpu_translation_rule, gpu_kernel) - - -def _taichi_mlir_cpu_translation_rule(kernel, c, *ins, **kwargs): - in_out_info = _compile_kernel(c.avals_in, kernel, 'cpu', **kwargs) - ins = [mlir.ir_constant(v) for v in in_out_info] + list(ins) - input_layouts = [_shape_to_layout(arr.shape) for arr in in_out_info] + [_shape_to_layout(a.shape) for a in c.avals_in] - output_layouts = tuple([_shape_to_layout(out.shape) for out in c.avals_out]) - result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] - if is_metal_device: - if len(output_layouts) == 1: - fn = 'taichi_kernel_aot_call_cpu_arm64_single_result' - else: - fn = 'taichi_kernel_aot_call_cpu_arm64' - else: - if len(output_layouts) == 1: - fn = 'taichi_kernel_aot_call_cpu_single_result' - else: - fn = 'taichi_kernel_aot_call_cpu' - return custom_call( - call_target_name=fn, - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - has_side_effect=False, - ).results - - -def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): - opaque = _compile_kernel(c.avals_in, kernel, 'gpu', **kwargs) - input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] - result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] - output_layouts = [_shape_to_layout(out.shape) for out in c.avals_out] - return custom_call( - call_target_name='taichi_kernel_aot_call_gpu', - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - backend_config=opaque, - has_side_effect=False, - ).results - - -def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel): - if import_taichi(error_if_not_found=False) is None: - raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') - - rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel) - mlir.register_lowering(primitive, rule, platform='cpu') - - -def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel): - if import_taichi(error_if_not_found=False) is None: - raise PackageMissingError.by_purpose("taichi", 'register taichi AOT based translation rule') - - rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel) - mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py deleted file mode 100644 index 85401d99..00000000 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ /dev/null @@ -1,68 +0,0 @@ -import jax -import jax.numpy as jnp -import pytest - -import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -import platform -force_test = False # turn on to force test on windows locally -if platform.system() == 'Windows' and not force_test: - pytest.skip('skip windows', allow_module_level=True) - -ti = import_taichi(error_if_not_found=False) -if ti is None: - pytest.skip('no taichi', allow_module_level=True) - -bm.set_platform('cpu') - -@ti.func -def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32: - return weight[None] - - -@ti.func -def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): - out[index] += weight_val - - -@ti.kernel -def event_ell_cpu(indices: ti.types.ndarray(ndim=2), - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=0), - out: ti.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - ti.loop_config(serialize=True) - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - -@ti.kernel -def event_ell_gpu(indices: ti.types.ndarray(ndim=2), - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=0), - out: ti.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - -prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) - - -def test_taichi_op_register(): - s = 1000 - indices = bm.random.randint(0, s, (s, 1000)) - vector = bm.random.rand(s) < 0.1 - - out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - print(out) - -# test_taichi_op_register() diff --git a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py b/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py deleted file mode 100644 index b534435d..00000000 --- a/brainpy/_src/math/op_register/tests/test_taichi_clean_cache.py +++ /dev/null @@ -1,60 +0,0 @@ -import jax -import jax.numpy as jnp - -import brainpy.math as bm -import taichi as ti - -from brainpy._src.dependency_check import import_taichi -ti = import_taichi(error_if_not_found=False) -if ti is None: - import pytest - pytest.skip('no taichi', allow_module_level=True) - - -@ti.func -def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: - return weight[0] - - -@ti.func -def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32): - out[index] += weight_val - - -@ti.kernel -def event_ell_cpu(indices: ti.types.ndarray(ndim=2), - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - weight_val = get_weight(weight) - num_rows, num_cols = indices.shape - ti.loop_config(serialize=True) - for i in range(num_rows): - if vector[i]: - for j in range(num_cols): - update_output(out, indices[i, j], weight_val) - - -prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu) - - -def test_taichi_clean_cache(): - s = 1000 - indices = bm.random.randint(0, s, (s, 100)) - vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - - print(out) - bm.clear_buffer_memory() - - print('kernels: ', bm.count_taichi_aot_kernels()) - - bm.clear_taichi_aot_caches() - - print('kernels: ', bm.count_taichi_aot_kernels()) - -# test_taichi_clean_cache() diff --git a/brainpy/_src/math/sparse/coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py index 2885d946..c9a46ff6 100644 --- a/brainpy/_src/math/sparse/coo_mv.py +++ b/brainpy/_src/math/sparse/coo_mv.py @@ -1,18 +1,14 @@ # -*- coding: utf-8 -*- -import warnings -from functools import partial from typing import Union, Tuple -import numpy as np -from jax import core, numpy as jnp, dtypes, default_backend -from jax.interpreters import ad, mlir -from jaxlib import gpu_sparse +from jax import numpy as jnp -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import register_general_batching +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'coomv', @@ -65,136 +61,17 @@ def coomv( An array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ - - data = jnp.atleast_1d(as_jax(data)) - row = as_jax(row) - col = as_jax(col) - vector = as_jax(vector) - - if method == 'cusparse': - if default_backend() != 'cpu': - if data.shape[0] == 1: - data = jnp.ones(row.shape, dtype=data.dtype) * data - if row.dtype in [jnp.uint32, jnp.uint64]: - row = jnp.asarray(row, dtype=dtypes.canonicalize_dtype(jnp.int64)) - if col.dtype in [jnp.uint32, jnp.uint64]: - col = jnp.asarray(col, dtype=dtypes.canonicalize_dtype(jnp.int64)) - return _coomv_cusparse_p.bind(data, - row, - col, - vector, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - else: - raise ValueError - - -# -------------------------------------------------------------------- -# cusparse_coo_matvec - - -def _coomv_impl(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - v = jnp.asarray(v) - if transpose: - row, col = col, row - out_shape = shape[1] if transpose else shape[0] - dv = data * v[col] - return jnp.zeros(out_shape, dv.dtype).at[row].add(dv) - - -def _coomv_abstract_eval(data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - assert data.shape == row.shape == col.shape - assert data.dtype == v.dtype - assert row.dtype == col.dtype - assert len(shape) == 2 - assert v.ndim == 1 - assert v.shape[0] == (shape[0] if transpose else shape[1]) - out_shape = shape[1] if transpose else shape[0] - return core.ShapedArray((out_shape,), data.dtype) - - -_coo_matvec_lowering = mlir.lower_fun(_coomv_impl, multiple_results=False) - - -def _coomv_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, - shape, rows_sorted, cols_sorted, transpose): - data_aval, row_aval, _, x_aval = ctx.avals_in - dtype = data_aval.dtype - if dtype not in [np.float32, np.float64, np.complex64, np.complex128]: - warnings.warn(f"cusparse_coo_matvec cusparse/hipsparse lowering not available for dtype={dtype}. " - "Falling back to default implementation.", UserWarning) - return _coo_matvec_lowering(ctx, data, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - if rows_sorted: - shape = shape - elif cols_sorted: - row, col = col, row - transpose = not transpose - shape = shape[::-1] - else: - warnings.warn("cusparse_coo_matvec GPU lowering requires matrices with sorted rows or sorted cols. " - "To sort the rows in your matrix, use e.g. mat = mat._sort_rows(). Falling " - "back to the default implementation.", UserWarning) - return _coo_matvec_lowering(ctx, data, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - return [coo_matvec_mhlo(data, row, col, v, - shape=shape, - transpose=transpose, - index_dtype=row_aval.dtype, - data_dtype=dtype, - x_dtype=x_aval.dtype)] - - -def _coomv_jvp_mat(data_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - return _coomv_cusparse_p.bind(data_dot, row, col, v, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - -def _coomv_jvp_vec(v_dot, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - return _coomv_cusparse_p.bind(data, row, col, v_dot, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=transpose) - - -def _coomv_transpose(ct, data, row, col, v, *, shape, rows_sorted, cols_sorted, transpose): - assert not ad.is_undefined_primal(row) - assert not ad.is_undefined_primal(col) - - if ad.is_undefined_primal(v): - return data, row, col, _coomv_cusparse_p.bind(data, row, col, ct, - shape=shape, - rows_sorted=rows_sorted, - cols_sorted=cols_sorted, - transpose=not transpose) - else: - return ct[row] * v[col], row, col, v - - -_coomv_cusparse_p = core.Primitive('cusparse_coo_matvec') -_coomv_cusparse_p.def_abstract_eval(_coomv_abstract_eval) -_coomv_cusparse_p.def_impl(_coomv_impl) -ad.defjvp(_coomv_cusparse_p, _coomv_jvp_mat, None, None, _coomv_jvp_vec) -ad.primitive_transposes[_coomv_cusparse_p] = _coomv_transpose -mlir.register_lowering(_coomv_cusparse_p, _coo_matvec_lowering) -mlir.register_lowering(_coomv_cusparse_p, - partial(_coomv_gpu_lowering, gpu_sparse.cuda_coo_matvec), - platform='cuda') -register_general_batching(_coomv_cusparse_p) - - + if bti is None: + raise_braintaichi_not_found() + + return bti.coomv( + data=data, + row=row, + col=col, + vector=vector, + shape=shape, + rows_sorted=rows_sorted, + cols_sorted=cols_sorted, + transpose=transpose, + method=method + ) diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index 47c24fa4..4d5b0d6c 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -3,19 +3,12 @@ from typing import Union, Tuple -import jax -import numpy as np from jax import numpy as jnp -from jax.experimental.sparse import csr -from jax.interpreters import ad -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (XLACustomOp, register_general_batching) -from brainpy.errors import PackageMissingError +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found -ti = import_taichi(error_if_not_found=False) +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmm', @@ -48,180 +41,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product. """ - return raw_csrmm_taichi(data, indices, indptr, matrix, shape=shape, transpose=transpose)[0] + if bti is None: + raise_braintaichi_not_found() - -def raw_csrmm_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - matrix: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - assert len(shape) == 2 - - indices = as_jax(indices) - indptr = as_jax(indptr) - matrix = as_jax(matrix) - data = jnp.atleast_1d(data) - - if matrix.dtype == jnp.bool_: - matrix = as_jax(matrix, dtype=data.dtype) - - if data.dtype != matrix.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {matrix.dtype}.') - assert matrix.ndim == 2 - - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - assert indptr.shape[0] == shape[0] + 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - out_shape = shape[1] if transpose else shape[0] - result_shape = (out_shape, matrix.shape[1]) - - assert matrix.shape[0] == (shape[0] if transpose else shape[1]) - - if indices.shape[0] == 0: - return [jnp.zeros(result_shape, dtype=data.dtype), ] - - # homo -> taichi, - # heter -> cusparse - if data.shape[0] != 1: - return [_csr_matmat_cusparse_p.bind(data, indices, indptr, matrix, shape=shape, transpose=transpose), ] - else: - if ti is None: - raise PackageMissingError.by_purpose('taichi', 'customzied sparse matrix multiplication') - if transpose: - prim = _csr_matmat_transpose_homo_p - else: - prim = _csr_matmat_homo_p - r = prim(indices, - indptr, - matrix, - outs=[jax.ShapeDtypeStruct(result_shape, dtype=matrix.dtype)], - transpose=transpose, - shape=shape) - return [r[0] * data] - - -# taichi kernels -if ti is not None: - # @ti.kernel - # def _csr_matmat_transpose_heter(values: ti.types.ndarray(ndim=1), - # col_indices: ti.types.ndarray(ndim=1), - # row_ptr: ti.types.ndarray(ndim=1), - # matrix: ti.types.ndarray(ndim=2), - # out: ti.types.ndarray(ndim=2)): - # for row_i in range(row_ptr.shape[0] - 1): - # for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - # col = col_indices[i] - # for j in range(out.shape[1]): - # out[col, j] += values[row_i] * matrix[row_i, j] - # - # @ti.kernel - # def _csr_matmat_heter(values: ti.types.ndarray(ndim=1), - # col_indices: ti.types.ndarray(ndim=1), - # row_ptr: ti.types.ndarray(ndim=1), - # matrix: ti.types.ndarray(ndim=2), - # out: ti.types.ndarray(ndim=2)): - # for row_i, col_k in ti.ndrange(out.shape[0], out.shape[1]): - # r = 0. - # for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - # r += values[j] * matrix[col_indices[j], col_k] - # out[row_i, col_k] = r - # - # # transpose heter - # _csr_matmat_transpose_heter_p = _define_op(cpu_kernel=_csr_matmat_transpose_heter, - # gpu_kernel=_csr_matmat_transpose_heter) - # - # # no transpose heter - # _csr_matmat_heter_p = _define_op(cpu_kernel=_csr_matmat_heter, - # gpu_kernel=_csr_matmat_heter) - - @ti.kernel - def _csr_matmat_transpose_homo_cpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - n = out.shape[1] - m = row_ptr.shape[0] - 1 - for j in range(n): # parallize along the n dimension - for row_i in range(m): # loop along the m dimension - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - - @ti.kernel - def _csr_matmat_transpose_homo_gpu(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - m = row_ptr.shape[0] - 1 - n = matrix.shape[1] - for j, row_i in ti.ndrange(n, m): # paralleize along the (n and m) dimensions - for i in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[i], j] += matrix[row_i, j] - - - @ti.kernel - def _csr_matmat_homo(col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - matrix: ti.types.ndarray(ndim=2), - out: ti.types.ndarray(ndim=2)): - # matrix: (k, n) - # sparse matrix: (m, k) - m, n = out.shape - for row_i, col_k in ti.ndrange(m, n): - r = 0. - for row_j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += matrix[col_indices[row_j], col_k] - out[row_i, col_k] = r - - - def _csr_matmat_jvp_matrix(mat_dot, col_indices, row_ptr, matrix, *, outs, transpose, shape): - if transpose: - return _csr_matmat_transpose_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) - else: - return _csr_matmat_homo_p(col_indices, row_ptr, mat_dot, shape=shape, transpose=transpose, outs=outs) - - - def _csr_matmat_transpose( - ct, col_indices, row_ptr, matrix, *, outs, transpose, shape, - ): - if ad.is_undefined_primal(col_indices) or ad.is_undefined_primal(row_ptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - assert ad.is_undefined_primal(matrix) - ct_matrix = raw_csrmm_taichi(jnp.ones(1), col_indices, row_ptr, ct[0], - shape=shape, - transpose=not transpose) - return col_indices, row_ptr, (ad.Zero(matrix) if type(ct[0]) is ad.Zero else ct_matrix[0]) - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(None, None, _csr_matmat_jvp_matrix) - prim.def_transpose_rule(_csr_matmat_transpose) - return prim - - - # transpose homo - _csr_matmat_transpose_homo_p = _define_op(cpu_kernel=_csr_matmat_transpose_homo_cpu, - gpu_kernel=_csr_matmat_transpose_homo_gpu) - - # no transpose homo - _csr_matmat_homo_p = _define_op(cpu_kernel=_csr_matmat_homo, gpu_kernel=_csr_matmat_homo) - - # heter CUSPARSE - _csr_matmat_cusparse_p = csr.csr_matmat_p - register_general_batching(_csr_matmat_cusparse_p) + return bti.csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py index 6eaf6b79..c39744bb 100644 --- a/brainpy/_src/math/sparse/csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -3,20 +3,12 @@ from typing import Union, Tuple -import jax from jax import numpy as jnp -from jax.experimental.sparse import csr -from jax.interpreters import ad -import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import (register_general_batching, XLACustomOp) -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy.errors import PackageMissingError -ti = import_taichi(error_if_not_found=False) +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmv', @@ -69,257 +61,8 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ + if bti is None: + raise_braintaichi_not_found() - data = jnp.atleast_1d(as_jax(data)) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) + return bti.csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose) - if vector.dtype == jnp.bool_: - vector = as_jax(vector, dtype=data.dtype) - - if data.dtype not in [jnp.float16, jnp.float32, jnp.float64]: - raise TypeError('Only support float16, float32 or float64 type. ' - f'But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. ' - f'But we got {data.dtype} != {vector.dtype}.') - assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - if not jnp.issubdtype(indices.dtype, jnp.integer): - raise ValueError('indices should be a 1D vector with integer type.') - if not jnp.issubdtype(indptr.dtype, jnp.integer): - raise ValueError('indptr should be a 1D vector with integer type.') - - # if the shape of indices is (0,), then we return a zero vector - if indices.shape[0] == 0: - return jnp.zeros(shape[1] if transpose else shape[0], dtype=data.dtype) - - return raw_csrmv_taichi(data, indices, indptr, vector, shape=shape, transpose=transpose)[0] - - -def raw_csrmv_taichi( - data: Union[float, jnp.ndarray, Array], - indices: Union[jnp.ndarray, Array], - indptr: Union[jnp.ndarray, Array], - vector: Union[jnp.ndarray, Array], - *, - shape: Tuple[int, int], - transpose: bool = False, -): - if ti is None: - raise PackageMissingError.by_purpose('taichi', purpose='customized operators') - - out_shape = shape[1] if transpose else shape[0] - if data.shape[0] != 1: - if bm.get_platform() == 'gpu': - return [_csr_matvec_cusparse_p.bind(data, indices, indptr, vector, shape=shape, transpose=transpose)] - else: - if transpose: - prim = _csr_matvec_transpose_heter_p - else: - prim = _csr_matvec_heter_p - else: - if transpose: - prim = _csr_matvec_transpose_homo_p - else: - prim = _csr_matvec_homo_p - - return prim(data, - indices, - indptr, - vector, - outs=[jax.ShapeDtypeStruct((out_shape,), dtype=data.dtype)], - transpose=transpose, - shape=shape) - - -if ti is not None: - - # ------------- - # CPU operators - # ------------- - @ti.kernel - def _sparse_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += value * vector[row_i] - - - @ti.kernel - def _sparse_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - out[col_indices[j]] += vector[row_i] * values[j] - - - @ti.kernel - def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += vector[col_indices[j]] - out[row_i] = r * value - - - @ti.kernel - def _sparse_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(row_ptr.shape[0] - 1): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - out[row_i] = r - - - # ------------- - # GPU operators - # ------------- - - @ti.kernel - def _sparse_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += value * vector[row_i] - j += 32 - - - @ti.kernel - def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += vector[col_indices[j]] - j += 32 - out[row_i] += value * r - - - @ti.kernel - def _sparse_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - out[col_indices[j]] += values[j] * vector[row_i] - j += 32 - - - @ti.kernel - def _sparse_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - col_indices: ti.types.ndarray(ndim=1), - row_ptr: ti.types.ndarray(ndim=1), - vector: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((row_ptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = row_ptr[row_i] + index - end_index = row_ptr[row_i + 1] - while j < end_index: - r += values[j] * vector[col_indices[j]] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - - def _sparse_csr_matvec_jvp_values(val_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(val_dot, col_indices, row_ptr, vector, shape=shape, transpose=transpose) - - - def _sparse_csr_matvec_jvp_vector(vec_dot, values, col_indices, row_ptr, vector, *, outs, transpose, shape): - return raw_csrmv_taichi(values, col_indices, row_ptr, vec_dot, shape=shape, transpose=transpose) - - - def _sparse_csr_matvec_transpose( - ct, data, indices, indptr, vector, *, outs, transpose, shape, - ): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_vector = raw_csrmv_taichi(data, indices, indptr, ct[0], shape=shape, transpose=not transpose)[0] - return data, indices, indptr, (ad.Zero(vector) if type(ct[0]) is ad.Zero else ct_vector) - - else: - if type(ct[0]) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = raw_csrmv_taichi(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)[0] - ct_data = jnp.inner(ct[0], ct_data) - else: - row, col = csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[0][col] if transpose else vector[col] * ct[0][row] - - return ct_data, indices, indptr, vector - - - def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_sparse_csr_matvec_jvp_values, None, None, _sparse_csr_matvec_jvp_vector) - prim.def_transpose_rule(_sparse_csr_matvec_transpose) - return prim - - - # transpose homo - _csr_matvec_transpose_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_homo_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_homo_gpu) - - # no transpose homo - _csr_matvec_homo_p = _define_op(cpu_kernel=_sparse_csr_matvec_homo_cpu, - gpu_kernel=_sparse_csr_matvec_homo_gpu) - - # transpose heter - _csr_matvec_transpose_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_transpose_heter_cpu, - gpu_kernel=_sparse_csr_matvec_transpose_heter_gpu) - - # no transpose heter - _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) - - # heter cusparse - _csr_matvec_cusparse_p = csr.csr_matvec_p - register_general_batching(_csr_matvec_cusparse_p) diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 25738db8..61032cf2 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -8,9 +8,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) import platform force_test = False # turn on to force test on windows locally diff --git a/brainpy/_src/math/tests/test_tifunc.py b/brainpy/_src/math/tests/test_tifunc.py index db6e7deb..5bf0a0ad 100644 --- a/brainpy/_src/math/tests/test_tifunc.py +++ b/brainpy/_src/math/tests/test_tifunc.py @@ -9,11 +9,6 @@ import matplotlib.pyplot as plt import os -from brainpy._src.dependency_check import import_taichi - -ti = import_taichi(error_if_not_found=False) -if ti is None: - pytest.skip('no taichi', allow_module_level=True) bm.set_platform('cpu') diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py deleted file mode 100644 index 9cfd39e1..00000000 --- a/brainpy/_src/math/tifunc.py +++ /dev/null @@ -1,345 +0,0 @@ -from brainpy._src.dependency_check import import_taichi, raise_taichi_not_found -from . import defaults - -ti = import_taichi(error_if_not_found=False) - -__all__ = [ - # taichi function for other utilities - 'warp_reduce_sum', - - # taichi functions for random number generator with LFSR88 algorithm - 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', - 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', - - # taichi functions for random number generator with LFSR113 algorithm - 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', - 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', -] - -if ti is not None: - - ############################################# - # Random Number Generator: LFSR88 algorithm # - ############################################# - - @ti.func - def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - - This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. - - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3 MUST be larger than - 1, 7, and 15 respectively. - */ - - Args: - seed: int. The seed value for the random number generator. - - Returns: - ti.math.uvec4: The random key for the LFSR88 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) - - - @ti.func - def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - - Args: - key: The state value for the random number generator. - - Returns: - ti.math.uvec4: The next random key. - """ - b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) - s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b - b = ((key[1] << 2) ^ key[1]) >> 25 - s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b - b = ((key[2] << 3) ^ key[2]) >> 11 - s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b - return ti.math.uvec4(s1, s2, s3, b) - - - @ti.func - def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ - - key, r = lfsr88_randn(key, epsilon) - return key, mu + sigma * r - - - @ti.func - def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with the standard normal distribution using the LFSR88 algorithm. - - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). - - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - - """ - - key, u1 = lfsr88_rand(key) - key, u2 = lfsr88_rand(key) - - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) - - return key, z2 - - - @ti.func - def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. - - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) - - - @ti.func - def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): - key = lfsr88_next_key(key) - return key, dtype(key[0] ^ key[1] ^ key[2]) - - - @ti.func - def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. - - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - - @ti.func - def lfsr88_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. - - Args: - key: The state value used for random number generation. - """ - key = lfsr88_next_key(key) - return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - - - ############################################## - # Random Number Generator: LFSR113 algorithm # - ############################################## - - @ti.func - def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - - This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c - - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3, s4 MUST be larger than - 1, 7, 15, and 127 respectively. - */ - - Args: - seed: int. The seed value for the random number generator. - - Returns: - ti.math.uvec4: The random key for the LFSR113 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) - - - @ti.func - def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - - Args: - key: The state value for the random number generator. - - Returns: - ti.math.uvec4: The next random key. - """ - z1 = key[0] - z2 = key[1] - z3 = key[2] - z4 = key[3] - b = ((z1 << 6) ^ z1) >> 13 - z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) - b = ((z2 << 2) ^ z2) >> 27 - z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) - b = ((z3 << 13) ^ z3) >> 21 - z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) - b = ((z4 << 3) ^ z4) >> 12 - z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) - return ti.math.uvec4(z1, z2, z3, z4) - - - @ti.func - def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ - - key, r = lfsr113_randn(key, epsilon) - return key, ti.cast(mu + sigma * r, defaults.ti_float) - - - @ti.func - def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with standard normal distribution using the LFSR113 algorithm. - - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). - - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - - """ - - key, u1 = lfsr113_rand(key) - key, u2 = lfsr113_rand(key) - - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) - - return key, z2 - - - @ti.func - def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. - - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr113_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) - - - @ti.func - def lfsr113_randint(key: ti.types.vector(4, ti.u32)): - key = lfsr113_next_key(key) - return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) - - - @ti.func - def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - - @ti.func - def lfsr113_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. - - Args: - key: The state value used for random number generation. - """ - key = lfsr113_next_key(key) - return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - - - ########################### - # Reductions: warp reduce # - ########################### - - @ti.func - def warp_reduce_sum_all(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for i in ti.static(range(1, 32)): - val += ti.static(ti.simt.warp.shfl_xor(val, i)) - return val - - - @ti.func - def warp_reduce_sum(val): - """ - Warp reduce sum. - - Args: - val (float): The value to be reduced. - - Returns: - float: The reduced value. - """ - for offset in ti.static((16, 8, 4, 2, 1)): - val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) - return val - - -else: - for func in __all__: - globals()[func] = raise_taichi_not_found \ No newline at end of file diff --git a/brainpy/_src/tests/test_dyn_runner.py b/brainpy/_src/tests/test_dyn_runner.py index 6f2411ee..037f283a 100644 --- a/brainpy/_src/tests/test_dyn_runner.py +++ b/brainpy/_src/tests/test_dyn_runner.py @@ -5,11 +5,6 @@ import brainpy as bp import brainpy.math as bm -from brainpy._src.dependency_check import import_taichi - -if import_taichi(error_if_not_found=False) is None: - pytest.skip('no taichi', allow_module_level=True) - class TestDSRunner(unittest.TestCase): def test1(self): diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 08a070f0..139ec08a 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -33,8 +33,6 @@ from . import linalg from . import random -# taichi operations -from . import tifunc # others from . import sharding diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py index f383c1a2..8ec7f5e1 100644 --- a/brainpy/math/op_register.py +++ b/brainpy/math/op_register.py @@ -2,8 +2,6 @@ from brainpy._src.math.op_register import ( CustomOpByNumba, compile_cpu_signature_with_numba, - clear_taichi_aot_caches, - count_taichi_aot_kernels, ) from brainpy._src.math.op_register.base import XLACustomOp diff --git a/brainpy/math/tifunc.py b/brainpy/math/tifunc.py deleted file mode 100644 index bea49c22..00000000 --- a/brainpy/math/tifunc.py +++ /dev/null @@ -1,25 +0,0 @@ -# -*- coding: utf-8 -*- - -from brainpy._src.math.tifunc import ( - - # warp reduction primitives - warp_reduce_sum, - - # random number generator - lfsr88_key, - lfsr88_next_key, - lfsr88_normal, - lfsr88_randn, - lfsr88_random_integers, - lfsr88_randint, - lfsr88_uniform, - lfsr88_rand, - lfsr113_key, - lfsr113_next_key, - lfsr113_normal, - lfsr113_randn, - lfsr113_random_integers, - lfsr113_randint, - lfsr113_uniform, - lfsr113_rand -) diff --git a/requirements-dev.txt b/requirements-dev.txt index 754073f4..1ad33b04 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,7 @@ matplotlib msgpack tqdm pathos -taichi==1.7.0 +braintaichi numba braincore braintools diff --git a/requirements-doc.txt b/requirements-doc.txt index e607c26c..5c6d440e 100644 --- a/requirements-doc.txt +++ b/requirements-doc.txt @@ -5,7 +5,7 @@ matplotlib numpy scipy numba -taichi==1.7.0 +braintaichi # document requirements pandoc diff --git a/setup.py b/setup.py index 55f948e4..84ac38c1 100644 --- a/setup.py +++ b/setup.py @@ -68,9 +68,9 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], - 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'braintaichi'], + 'cuda11': ['jaxlib[cuda11_pip]', 'brainpylib', 'numba', 'braintaichi'], + 'cuda12': ['jaxlib[cuda12_pip]', 'brainpylib', 'numba', 'braintaichi'], 'tpu': ['jaxlib[tpu]', 'numba',], 'cpu_mini': ['jaxlib>=0.4.13'], 'cuda11_mini': ['jaxlib[cuda11_pip]'],