From 55c69bb3659f555263096b3a118fccce73600fd7 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Sat, 17 Feb 2024 20:32:11 +0800 Subject: [PATCH] try to remove hard dependency with taichi and numba --- brainpy/_src/math/defaults.py | 17 +- brainpy/_src/math/environment.py | 21 +- brainpy/_src/math/event/__init__.py | 1 - brainpy/_src/math/event/_csr_matvec.py | 1226 +++----- brainpy/_src/math/event/_info_collection.py | 198 -- .../tests/event_info_VS_jax_operators.py | 275 -- .../_src/math/event/tests/test_event_csrmv.py | 7 + .../math/event/tests/test_event_csrmv_old.py | 8 +- brainpy/_src/math/event/tests/test_info.py | 62 - .../_src/math/event/tests/test_info_gpu.py | 14 - brainpy/_src/math/index_tricks.py | 305 -- brainpy/_src/math/jitconn/_event_matvec.py | 2621 +++++++---------- brainpy/_src/math/jitconn/_matvec.py | 2065 ++++--------- .../math/jitconn/tests/test_event_matvec.py | 6 + .../_src/math/jitconn/tests/test_matvec.py | 5 + brainpy/_src/math/op_register/numba_based.py | 1 - brainpy/_src/math/sparse/__init__.py | 4 +- brainpy/_src/math/sparse/tests/test_csrmv.py | 6 +- .../_src/math/sparse/tests/test_csrmv_old.py | 6 +- brainpy/_src/math/tifunc.py | 538 ++-- brainpy/errors.py | 15 +- brainpy/math/event.py | 1 - brainpy/math/sparse.py | 1 - requirements.txt | 2 - setup.py | 11 +- 25 files changed, 2318 insertions(+), 5098 deletions(-) delete mode 100644 brainpy/_src/math/event/_info_collection.py delete mode 100644 brainpy/_src/math/event/tests/event_info_VS_jax_operators.py delete mode 100644 brainpy/_src/math/event/tests/test_info.py delete mode 100644 brainpy/_src/math/event/tests/test_info_gpu.py delete mode 100644 brainpy/_src/math/index_tricks.py diff --git a/brainpy/_src/math/defaults.py b/brainpy/_src/math/defaults.py index 19aca92cf..dae0f1bcd 100644 --- a/brainpy/_src/math/defaults.py +++ b/brainpy/_src/math/defaults.py @@ -24,15 +24,20 @@ # '''Default integer data type.''' int_ = jnp.int64 if config.read('jax_enable_x64') else jnp.int32 -# '''Default integer data type in Taichi.''' -ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 - # '''Default float data type.''' float_ = jnp.float64 if config.read('jax_enable_x64') else jnp.float32 -# '''Default float data type in Taichi.''' -ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 - # '''Default complex data type.''' complex_ = jnp.complex128 if config.read('jax_enable_x64') else jnp.complex64 + +if ti is not None: + # '''Default integer data type in Taichi.''' + ti_int = ti.int64 if config.read('jax_enable_x64') else ti.int32 + + # '''Default float data type in Taichi.''' + ti_float = ti.float64 if config.read('jax_enable_x64') else ti.float32 + +else: + ti_int = None + ti_float = None \ No newline at end of file diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 1c8b98a3b..757c19b8d 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -416,13 +416,16 @@ def set_float(dtype: type): """ if dtype in [jnp.float16, 'float16', 'f16']: defaults.__dict__['float_'] = jnp.float16 - defaults.__dict__['ti_float'] = ti.float16 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float16 elif dtype in [jnp.float32, 'float32', 'f32']: defaults.__dict__['float_'] = jnp.float32 - defaults.__dict__['ti_float'] = ti.float32 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float32 elif dtype in [jnp.float64, 'float64', 'f64']: defaults.__dict__['float_'] = jnp.float64 - defaults.__dict__['ti_float'] = ti.float64 + if ti is not None: + defaults.__dict__['ti_float'] = ti.float64 else: raise NotImplementedError @@ -448,16 +451,20 @@ def set_int(dtype: type): """ if dtype in [jnp.int8, 'int8', 'i8']: defaults.__dict__['int_'] = jnp.int8 - defaults.__dict__['ti_int'] = ti.int8 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int8 elif dtype in [jnp.int16, 'int16', 'i16']: defaults.__dict__['int_'] = jnp.int16 - defaults.__dict__['ti_int'] = ti.int16 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int16 elif dtype in [jnp.int32, 'int32', 'i32']: defaults.__dict__['int_'] = jnp.int32 - defaults.__dict__['ti_int'] = ti.int32 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int32 elif dtype in [jnp.int64, 'int64', 'i64']: defaults.__dict__['int_'] = jnp.int64 - defaults.__dict__['ti_int'] = ti.int64 + if ti is not None: + defaults.__dict__['ti_int'] = ti.int64 else: raise NotImplementedError diff --git a/brainpy/_src/math/event/__init__.py b/brainpy/_src/math/event/__init__.py index 631129558..e61dc10cf 100644 --- a/brainpy/_src/math/event/__init__.py +++ b/brainpy/_src/math/event/__init__.py @@ -1,4 +1,3 @@ -from ._info_collection import * from ._csr_matvec import * diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py index 6e03be463..f4f23fa93 100644 --- a/brainpy/_src/math/event/_csr_matvec.py +++ b/brainpy/_src/math/event/_csr_matvec.py @@ -10,27 +10,19 @@ """ -from functools import partial from typing import Union, Tuple import jax import jax.numpy as jnp -import numba import numpy as np -from jax.core import ShapedArray, Primitive -from jax.interpreters import ad, xla -from jax.lib import xla_client +from jax.interpreters import ad -from brainpy._src.dependency_check import (import_brainpylib_gpu_ops) from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching, - XLACustomOp) -from brainpy._src.math.sparse._csr_mv import csrmv_brainpylib as normal_csrmv +from brainpy._src.math.op_register import XLACustomOp from brainpy._src.math.sparse._csr_mv import raw_csrmv_taichi as normal_csrmv_taichi from brainpy._src.math.sparse._utils import csr_to_coo -from brainpy.errors import GPUOperatorNotFound +from brainpy.errors import PackageMissingError __all__ = [ 'csrmv' @@ -81,535 +73,6 @@ def csrmv( return csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose) -### BRAINPYLIB ### - -def csrmv_brainpylib( - data: Union[float, jax.Array], - indices: jax.Array, - indptr: jax.Array, - events: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False -) -> jax.Array: - """Product of a sparse CSR matrix and a dense event vector. - - This function supports JAX transformations, including `jit()`, `grad()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - data: ndarray, float - An array of shape ``(nse,)``. - indices: ndarray - An array of shape ``(nse,)``. - indptr: ndarray - An array of shape ``(shape[0] + 1,)`` and dtype ``indices.dtype``. - events: ndarray - An array of shape ``(shape[0] if transpose else shape[1],)`` - and dtype ``data.dtype``. - shape: tuple - A length-2 tuple representing the matrix shape. - transpose: bool - A boolean specifying whether to transpose the sparse matrix - before computing. - If ``transpose=True``, the operator will compute based on the - event-driven property of the ``events`` vector. - - Returns - ------- - y : Array - The array of shape ``(shape[1] if transpose else shape[0],)`` representing - the matrix vector product. - """ - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - events = as_jax(events) - # checking - data = jnp.atleast_1d(data) - if np.ndim(data) == 1: - if data.shape[0] not in [1, indices.shape[0]]: - raise ValueError('The size of data should be 1 or be consistent with indices.' - f'But we got {data.shape} != {indices.shape}, {data.shape} != 1.') - else: - raise ValueError('data should be a scalar or 1D vector. ' - f'But we got {np.ndim(data)}-D array.') - if np.ndim(indices) != 1: - raise ValueError('indices should be a 1D vector with integer type.') - if np.ndim(indptr) != 1: - raise ValueError('indptr should be a 1D vector with integer type.') - if indices.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indices should be a 1D vector with int32 or int64 type.') - if indptr.dtype not in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64]: - raise ValueError('indptr should be a 1D vector with int32 or int64 type.') - if np.ndim(events) != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - # computing - return event_csr_matvec_p.bind(data, indices, indptr, events, shape=shape, transpose=transpose) - - -# ---------------------------------------------------------- -# event csr matvec -# ---------------------------------------------------------- - -# operator for `event_csr_matvec` batching rule -# -------- - -def _batch_event_csr_matvec_abstract( - values, indices, indptr, events, *, batch_size, shape, transpose=False -): - return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0])) - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_transpose_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, _ = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - values_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - value = values[values_bi, 0] - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += value - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in range(shape[0]): - if events[event_bi, row_i]: - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - res_val[bi, col_i] += values[value_bi, j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_csr_matvec_numba_imp(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, batch_size, shape, transpose = ins - batch_size = batch_size[()] - event_batch_dim = events.shape[0] - indices_batch_dim = indices.shape[0] - indptr_batch_dim = indptr.shape[0] - values_batch_dim = values.shape[0] - - if values.shape[1] == 1: # homogeneous value - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - value = values[value_bi, 0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += value - res_val[bi, row_i] = r - - else: # heterogeneous values - for bi in numba.prange(batch_size): - event_bi = bi % event_batch_dim - indptr_bi = bi % indptr_batch_dim - indices_bi = bi % indices_batch_dim - value_bi = bi % values_batch_dim - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[indptr_bi, row_i], indptr[indptr_bi, row_i + 1]): - col_i = indices[indices_bi, j] - if events[event_bi, col_i]: - r += values[value_bi, j] - res_val[bi, row_i] = r - - -def _batch_event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - inputs = (values, indices, indptr, events) - description = dict(batch_size=batch_size, shape=shape, transpose=transpose) - if transpose: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_transpose_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - else: - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _batch_event_csr_matvec_numba_imp, - _batch_event_csr_matvec_abstract, - False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, - name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _batch_event_csr_matvec_gpu_translation(c, values, indices, indptr, events, *, - batch_size, shape, transpose): - pass - - -def _batch_event_csr_matvec_jvp_values(values_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return event_csr_matvec_batching_p.bind(values_dot, indices, indptr, events, - batch_size=batch_size, shape=shape, transpose=transpose) - - -def _batch_csr_matvec(values, indices, indptr, vectors, *, shape, transpose): - f = jax.vmap(partial(normal_csrmv, shape=shape, transpose=transpose), - in_axes=(0 if values.shape[0] > 1 else None, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if vectors.shape[0] > 1 else None)) - return f(values if values.shape[0] > 1 else values[0], - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - vectors if vectors.shape[0] > 1 else vectors[0]) - - -def _batch_event_csr_matvec_jvp_events(events_dot, values, indices, indptr, events, *, - batch_size, shape, transpose): - return _batch_csr_matvec(values, indices, indptr, events_dot, - shape=shape, transpose=transpose) - - -def _batch_event_csr_matvec_transpose(ct, values, indices, indptr, events, *, - batch_size, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - if ad.is_undefined_primal(events): - ct_events = ( - ad.Zero(events.aval) if type(ct) is ad.Zero else - _batch_csr_matvec(ct, indices, indptr, values, - shape=shape, transpose=not transpose) - ) - return values, indices, indptr, ct_events - else: - if values.aval.shape[1] == 1: # scalar - temp = event_csr_matvec_batching_p.bind(jnp.ones((1, 1)), indices, indptr, events, - batch_size=batch_size, shape=shape, - transpose=transpose) - ct_values = jax.vmap(jnp.inner)(ct, temp) - else: # heterogeneous values - if type(ct) is ad.Zero: - ct_values = ad.Zero(values.aval) - else: - - def _f(ct, indices, indptr, events, *, transpose): - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values - - f = jax.vmap(partial(_f, transpose=transpose), - in_axes=(0, - 0 if indices.shape[0] > 1 else None, - 0 if indptr.shape[0] > 1 else None, - 0 if events.shape[0] > 1 else None)) - ct_values = f(ct, - indices if indices.shape[0] > 1 else indices[0], - indptr if indptr.shape[0] > 1 else indptr[0], - events if events.shape[0] > 1 else events[0]) - return ct_values, indices, indptr, events - - -event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching') -event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract) -event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation -ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values, - None, None, _batch_event_csr_matvec_jvp_events) -ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose - - -# operator for `event_csr_matvec` # -# ------------------------------- # - - -def _event_csr_matvec_abstract(values, indices, indptr, events, *, shape, transpose=False): - return ShapedArray(dtype=values.dtype, shape=(shape[1] if transpose else shape[0],)) - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True) -def _event_csr_matvec_transpose_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - if values.shape[0] > 1: # heter - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values[j] - - else: # homo - values = values[0] - for row_i, event in enumerate(events): - if event > 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - res_val[col_i] += values - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp1_bool(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i]: - r += values - res_val[row_i] = r - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _event_csr_matvec_numba_imp2(outs, ins): - res_val = outs - res_val.fill(0) - values, indices, indptr, events, shape, _ = ins - - if values.shape[0] > 1: # heter - for row_i in range(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values[j] - res_val[row_i] = r - - else: # homo - values = values[0] - for row_i in numba.prange(shape[0]): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - col_i = indices[j] - if events[col_i] > 0.: - r += values - res_val[row_i] = r - - -def _event_csr_matvec_cpu_translation(c, values, indices, indptr, events, *, shape, transpose): - inputs = (values, indices, indptr, events) - event_type = c.get_shape(events) - description = dict(shape=shape, transpose=transpose) - if transpose: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_transpose_numba_imp1_bool - else: - imp = _event_csr_matvec_transpose_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - else: - if event_type.element_type() == jnp.bool_: - imp = _event_csr_matvec_numba_imp1_bool - else: - imp = _event_csr_matvec_numba_imp2 - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - imp, - abs_eval_fn=_event_csr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _event_csr_matvec_gpu_translation(c, data, indices, indptr, vector, *, shape, transpose): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_csr_matvec_p.name) - - # shape checking - data_shape = c.get_shape(data) - indices_shape = c.get_shape(indices) - indptr_shape = c.get_shape(indptr) - vec_shape = c.get_shape(vector) - if data_shape.element_type() == jnp.float32: - ftype = b'_float' - elif data_shape.element_type() == jnp.float64: - ftype = b'_double' - else: - raise ValueError - assert indices_shape.element_type() == indptr_shape.element_type() - if indices_shape.element_type() == jnp.int32: - itype = b'_int' - elif indices_shape.element_type() == jnp.int64: - itype = b'_long' - else: - raise ValueError - data_name = b'_homo' if data_shape.dimensions() == (1,) else b'_heter' - tran_type = b'_transpose' if transpose else b'' - if vec_shape.element_type() == jnp.bool_: - vec_type = b'_bool' - else: - assert vec_shape.element_type() == data_shape.element_type() - vec_type = b'' - - # opaque - opaque = gpu_ops.build_double_size_descriptor(shape[0], shape[1]) - - # call - return xla_client.ops.CustomCallWithLayout( - c, - b'event_csrmv' + data_name + ftype + itype + vec_type + tran_type, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector)), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), - (shape[1] if transpose else shape[0],), - (0,)), - opaque=opaque, - ) - - -def _event_csr_matvec_batching_rule(args, axes, *, shape, transpose): - batch_size = 0 - args_processed = [] - for arg, axis in zip(args, axes): - if axis is None: - arg = jnp.expand_dims(jnp.atleast_1d(arg), 0) - else: - batch_size = arg.shape[axis] - if axis > 0: - arg = jnp.moveaxis(arg, axis, 0) - args_processed.append(arg) - - r = event_csr_matvec_batching_p.bind(*args_processed, - batch_size=batch_size, - shape=shape, - transpose=transpose) - return r, 0 - - -def _event_csr_matvec_jvp_values_brainpylib(values_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values_dot, indices, indptr, events, shape=shape, transpose=transpose) - - -def _event_csr_matvec_jvp_events_brainpylib(events_dot, values, indices, indptr, events, *, shape, transpose): - return normal_csrmv(values, indices, indptr, events_dot, shape=shape, transpose=transpose) - - -def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events, *, shape, transpose): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv(values, indices, indptr, ct, shape=shape, transpose=not transpose) - return values, indices, indptr, (ad.Zero(events) if type(ct) is ad.Zero else ct_events) - else: - if type(ct) is ad.Zero: - ct_values = ad.Zero(values) - else: - if values.aval.shape[0] == 1: # scalar - ct_values = csrmv_brainpylib(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose) - ct_values = jnp.inner(ct, ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[col] if transpose else events[col] * ct[row] - return ct_values, indices, indptr, events - - -event_csr_matvec_p = Primitive('event_csr_matvec') -event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract) -event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p)) -# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation -# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation -ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None, - _event_csr_matvec_jvp_events_brainpylib) -ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib -register_general_batching(event_csr_matvec_p) - - -# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule - - -### TAICHI ### - def csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -691,298 +154,6 @@ def csrmv_taichi( return raw_csrmv_taichi(data, indices, indptr, events, shape=shape, transpose=transpose)[0] -# ------------- -# CPU operators -# ------------- - -# 1. The benchmarking shows that the performance of the following transpose -# kernels is maximized when using serialized mode -# 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable -# arguments, we have to define each kernel separately when the -# non-differentiable/non-jittable arguments are different. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i]: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += value - - -@ti.kernel -def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - if events[row_i] != 0.: - for j in range(indptr[row_i], indptr[row_i + 1]): - out[indices[j]] += values[j] - - -@ti.kernel -def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]]: - r += values[j] - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += value - out[row_i] = r - - -@ti.kernel -def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - # ti.loop_config(serialize=True) - for row_i in range(indptr.shape[0] - 1): - r = 0. - for j in range(indptr[row_i], indptr[row_i + 1]): - if events[indices[j]] != 0.: - r += values[j] - out[row_i] = r - - -# ------------- -# GPU operators -# ------------- - -# 1. GPU kernels are different from the CPU ones, since the GPU kernels need -# to use warp-level parallelism to achieve the best performance. - - -@ti.kernel -def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += value - j += 32 - - -# TODO -# It is important to note that the following warp-based kernels -# should be improved, since the atomic_add for each thread is not -# very efficient. Instead, the warp-level reduction primitive -# should be used. -# see ``warp_reduce_sum()`` function in tifunc.py. -# However, currently Taichi does not support general warp-level primitives. - - -@ti.kernel -def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - value = values[0] - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += value - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i]: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - if events[row_i] != 0.: - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - out[indices[j]] += values[j] - j += 32 - - -@ti.kernel -def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]]: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - -@ti.kernel -def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), - indices: ti.types.ndarray(ndim=1), - indptr: ti.types.ndarray(ndim=1), - events: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1)): - for i in range((indptr.shape[0] - 1) * 32): - row_i = i >> 5 - index = i & 31 - r = 0. - j = indptr[row_i] + index - end_index = indptr[row_i + 1] - while j < end_index: - if events[indices[j]] != 0.: - r += values[j] - j += 32 - out[row_i] += r # TODO: warp-level primitive - - def raw_csrmv_taichi( data: Union[float, jax.Array], indices: jax.Array, @@ -992,6 +163,9 @@ def raw_csrmv_taichi( shape: Tuple[int, int], transpose: bool = False ): + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + if transpose: if events.dtype == jnp.bool_: if data.shape[0] == 1: @@ -1025,65 +199,361 @@ def raw_csrmv_taichi( shape=shape) -def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) +if ti is not None: + + # ------------- + # CPU operators + # ------------- + + # 1. The benchmarking shows that the performance of the following transpose + # kernels is maximized when using serialized mode + # 2. Since our Taichi-JAX kernel does not support the non-differentiable/non-jittable + # arguments, we have to define each kernel separately when the + # non-differentiable/non-jittable arguments are different. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i]: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_transpose_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += value + + + @ti.kernel + def _event_csr_matvec_transpose_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + if events[row_i] != 0.: + for j in range(indptr[row_i], indptr[row_i + 1]): + out[indices[j]] += values[j] + + + @ti.kernel + def _event_csr_matvec_bool_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_bool_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]]: + r += values[j] + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += value + out[row_i] = r + + + @ti.kernel + def _event_csr_matvec_heter_cpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + # ti.loop_config(serialize=True) + for row_i in range(indptr.shape[0] - 1): + r = 0. + for j in range(indptr[row_i], indptr[row_i + 1]): + if events[indices[j]] != 0.: + r += values[j] + out[row_i] = r + + + # ------------- + # GPU operators + # ------------- + + # 1. GPU kernels are different from the CPU ones, since the GPU kernels need + # to use warp-level parallelism to achieve the best performance. + + @ti.kernel + def _event_csr_matvec_transpose_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += value + j += 32 + + + # TODO + # It is important to note that the following warp-based kernels + # should be improved, since the atomic_add for each thread is not + # very efficient. Instead, the warp-level reduction primitive + # should be used. + # see ``warp_reduce_sum()`` function in tifunc.py. + # However, currently Taichi does not support general warp-level primitives. + + @ti.kernel + def _event_csr_matvec_bool_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + value = values[0] + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += value + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_transpose_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i]: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_transpose_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + if events[row_i] != 0.: + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + out[indices[j]] += values[j] + j += 32 + + + @ti.kernel + def _event_csr_matvec_bool_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]]: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive + + + @ti.kernel + def _event_csr_matvec_heter_gpu(values: ti.types.ndarray(ndim=1), + indices: ti.types.ndarray(ndim=1), + indptr: ti.types.ndarray(ndim=1), + events: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1)): + for i in range((indptr.shape[0] - 1) * 32): + row_i = i >> 5 + index = i & 31 + r = 0. + j = indptr[row_i] + index + end_index = indptr[row_i + 1] + while j < end_index: + if events[indices[j]] != 0.: + r += values[j] + j += 32 + out[row_i] += r # TODO: warp-level primitive -def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): - return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + def _event_csr_matvec_jvp_values_taichi(val_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(val_dot, indices, indptr, events, shape=shape, transpose=transpose) -def _event_csr_matvec_transpose_taichi( - ct, values, indices, indptr, events, *, outs, transpose, shape -): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(events): - ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] - return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) - else: - if type(ct[0]) is ad.Zero: - ct_values = ad.Zero(values) + def _event_csr_matvec_jvp_events_taichi(evt_dot, values, indices, indptr, events, *, outs, transpose, shape): + return normal_csrmv_taichi(values, indices, indptr, evt_dot, shape=shape, transpose=transpose) + + + def _event_csr_matvec_transpose_taichi( + ct, values, indices, indptr, events, *, outs, transpose, shape + ): + if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): + raise ValueError("Cannot transpose with respect to sparse indices.") + if ad.is_undefined_primal(events): + ct_events = normal_csrmv_taichi(values, indices, indptr, ct[0], shape=shape, transpose=transpose)[0] + return values, indices, indptr, (ad.Zero(events) if type(ct[0]) is ad.Zero else ct_events) else: - if values.aval.shape[0] == 1: # scalar - ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] - ct_values = jnp.inner(ct[0], ct_values) - else: # heterogeneous values - row, col = csr_to_coo(indices, indptr) - ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] - return ct_values, indices, indptr, events + if type(ct[0]) is ad.Zero: + ct_values = ad.Zero(values) + else: + if values.aval.shape[0] == 1: # scalar + ct_values = raw_csrmv_taichi(jnp.ones(1), indices, indptr, events, shape=shape, transpose=transpose)[0] + ct_values = jnp.inner(ct[0], ct_values) + else: # heterogeneous values + row, col = csr_to_coo(indices, indptr) + ct_values = events[row] * ct[0][col] if transpose else events[col] * ct[0][row] + return ct_values, indices, indptr, events -def _define_op(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) - prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) - return prim + def _define_op(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_csr_matvec_jvp_values_taichi, None, None, _event_csr_matvec_jvp_events_taichi) + prim.def_transpose_rule(_event_csr_matvec_transpose_taichi) + return prim -# transpose bool homo -_event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, - _event_csr_matvec_transpose_bool_homo_gpu) + # transpose bool homo + _event_csrmv_transpose_bool_homo_p = _define_op(_event_csr_matvec_transpose_bool_homo_cpu, + _event_csr_matvec_transpose_bool_homo_gpu) -# transpose homo -_event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, _event_csr_matvec_transpose_homo_gpu) + # transpose homo + _event_csrmv_transpose_homo_p = _define_op(_event_csr_matvec_transpose_homo_cpu, + _event_csr_matvec_transpose_homo_gpu) -# not transpose bool homo -_event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, _event_csr_matvec_bool_homo_gpu) + # not transpose bool homo + _event_csrmv_bool_homo_p = _define_op(_event_csr_matvec_bool_homo_cpu, + _event_csr_matvec_bool_homo_gpu) -# not transpose homo -_event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, _event_csr_matvec_homo_gpu) + # not transpose homo + _event_csrmv_homo_p = _define_op(_event_csr_matvec_homo_cpu, + _event_csr_matvec_homo_gpu) -# transpose bool heter -_event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, - _event_csr_matvec_transpose_bool_heter_gpu) + # transpose bool heter + _event_csrmv_transpose_bool_heter_p = _define_op(_event_csr_matvec_transpose_bool_heter_cpu, + _event_csr_matvec_transpose_bool_heter_gpu) -# transpose heter -_event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, - _event_csr_matvec_transpose_heter_gpu) + # transpose heter + _event_csrmv_transpose_heter_p = _define_op(_event_csr_matvec_transpose_heter_cpu, + _event_csr_matvec_transpose_heter_gpu) -# not transpose bool heter -_event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, _event_csr_matvec_bool_heter_gpu) + # not transpose bool heter + _event_csrmv_bool_heter_p = _define_op(_event_csr_matvec_bool_heter_cpu, + _event_csr_matvec_bool_heter_gpu) -# not transpose heter -_event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, _event_csr_matvec_heter_gpu) + # not transpose heter + _event_csrmv_heter_p = _define_op(_event_csr_matvec_heter_cpu, + _event_csr_matvec_heter_gpu) diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py deleted file mode 100644 index 7bb043e3e..000000000 --- a/brainpy/_src/math/event/_info_collection.py +++ /dev/null @@ -1,198 +0,0 @@ -# -*- coding: utf-8 -*- - -from typing import Tuple, Union - -import jax -import numba -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client - -from brainpy._src.dependency_check import import_brainpylib_gpu_ops -from brainpy._src.dependency_check import import_taichi -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy.errors import GPUOperatorNotFound - -ti = import_taichi() - -__all__ = [ - 'info' -] - - -def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]: - """Collect event information, including event indices, and event number. - - This function supports JAX transformations, including `jit()`, - `vmap()` and `pmap()`. - - Parameters - ---------- - events: jax.Array - The events. - - Returns - ------- - res: tuple - A tuple with two elements, denoting the event indices and the event number. - """ - events = as_jax(events) - if events.ndim != 1: - raise TypeError('Only support 1D boolean vector.') - return event_info_p(events) - - -def _batch_event_info_abstract(events): - assert events.ndim == 2 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(events.shape[0],)) - return event_ids, event_num - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _batch_event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -@ti.kernel -def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2), - event_ids: ti.types.ndarray(ndim=2), - event_num: ti.types.ndarray(ndim=1)): - for i, j in ti.grouped(ti.ndrange(event_ids.shape)): - event_ids[i, j] = -1 - for batch_idx in range(event_ids.shape[0]): - num = 0 - for i in range(event_ids.shape[1]): - if events[batch_idx, i]: - event_ids[batch_idx, num] = i - num += 1 - event_num[batch_idx] = num - - -def _batch_event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - shape = arg.shape - arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2])) - event_ids, event_num = batch_event_info_p(arg) - return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])), - (0, 0)) - - -def _event_info_gpu_translation(c, events): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_info_p.name) - - e_shape = c.get_shape(events).dimensions() - e_type = c.get_shape(events).element_type() - if len(e_shape) == 1: - event_size = e_shape[0] - batch_size = 1 - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (event_size,), - (0,)) - else: - batch_size, event_size = e_shape - event_ids_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size, event_size), - (1, 0)) - event_num_shape = xla_client.Shape.array_shape(dtypes.canonicalize_dtype(int), - (batch_size,), - (0,)) - opaque = gpu_ops.build_nonzero_descriptor(event_size, batch_size) - - if e_type == jnp.bool_: - type_name = b'_bool' - elif e_type == jnp.int32: - type_name = b'_int' - elif e_type == jnp.int64: - type_name = b'_long' - elif e_type == jnp.float32: - type_name = b'_float' - elif e_type == jnp.float64: - type_name = b'_double' - else: - raise ValueError - - return xla_client.ops.CustomCallWithLayout( - c, - b'nonzero' + type_name, - operands=(events,), - operand_shapes_with_layout=(c.get_shape(events),), - shape_with_layout=xla_client.Shape.tuple_shape((event_ids_shape, event_num_shape)), - opaque=opaque, - ) - - -batch_event_info_p = XLACustomOp( - name='batched_event_info', - cpu_kernel=_batch_event_info_taichi, - gpu_kernel=_batch_event_info_taichi, - outs=_batch_event_info_abstract, -) -batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule) - - -def _event_info_abstract(events, **kwargs): - assert events.ndim == 1 - # assert events.dtype == jnp.bool_ - event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape) - event_num = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=(1,)) - return event_ids, event_num - - -# TODO: first parallel evaluate the sub-sections, then serially event the sub-results. -@numba.njit(fastmath=True) -def _event_info(outs, ins): - event_ids, event_num = outs - event_num.fill(0) - event_ids.fill(-1) - events = ins - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -@ti.kernel -def _event_info_taichi(events: ti.types.ndarray(ndim=1), - event_ids: ti.types.ndarray(ndim=1), - event_num: ti.types.ndarray(ndim=1)): - for i in range(event_ids.shape[0]): - event_ids[i] = -1 - num = 0 - for i in range(event_ids.shape[0]): - if events[i]: - event_ids[num] = i - num += 1 - event_num[0] = num - - -def _event_info_batching_rule(args, axes): - arg = jnp.moveaxis(args[0], axes[0], 0) - return (batch_event_info_p(arg), (0, 0)) - - -event_info_p = XLACustomOp( - name='event_info', - cpu_kernel=_event_info_taichi, - gpu_kernel=_event_info_taichi, - outs=_event_info_abstract, - # gpu_func_translation=_event_info_gpu_translation, -) -event_info_p.def_batching_rule(_event_info_batching_rule) diff --git a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py b/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py deleted file mode 100644 index 74cc6b7f9..000000000 --- a/brainpy/_src/math/event/tests/event_info_VS_jax_operators.py +++ /dev/null @@ -1,275 +0,0 @@ -from time import time - -from jax import jit, vmap, numpy as jnp - -import brainpy.math as bm - - -def compare_argsort_and_sum(platform='cpu'): - """ - CPU - --- - - shape = (100, 10000) - brainpylib 0.1872694492340088 s - JAX argsort + sum 5.297466516494751 s - - shape = (100, 100000) - brainpylib 2.333505153656006 s - JAX argsort + sum 65.20281910896301 s - - shape = (1000, 10000) - brainpylib 2.0739688873291016 s - JAX argsort + sum 53.70602822303772 s - - shape = (10000, 1000) - brainpylib 1.7262670993804932 s - JAX argsort + sum 43.92174816131592 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.14670848846435547 s - JAX argsort + sum 1.001936435699463 s - - shape = (100, 1000000) - brainpylib 0.27660632133483887 s - JAX argsort + sum 16.390073776245117 s - - shape = (1000, 100000) - brainpylib 0.2619345188140869 s - JAX argsort + sum 9.715844869613647 s - - shape = (1000, 500000) - brainpylib 1.201209306716919 s - JAX argsort + sum 71.19761657714844 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: (jnp.argsort(events), jnp.sum(events)))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids2, event_num2 = jax_event_info(events) - assert jnp.allclose(event_num1, event_num2) - event_ids1.block_until_ready() - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, b = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort + sum {time() - t0} s') - - print() - - -def compare_argsort(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.19738531112670898 s - JAX argsort 5.301469087600708 s - - shape = (100, 100000) - brainpylib 2.3321938514709473 s - JAX argsort 65.13460850715637 s - - shape = (1000, 10000) - brainpylib 2.0956876277923584 s - JAX argsort 53.863110065460205 s - - shape = (10000, 1000) - brainpylib 1.7127799987792969 s - JAX argsort 44.05547475814819 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.1415419578552246 s - JAX argsort 0.9982438087463379 s - - shape = (100, 1000000) - brainpylib 0.3224947452545166 s - JAX argsort 16.504750967025757 s - - shape = (1000, 100000) - brainpylib 0.2781648635864258 s - JAX argsort 9.691488981246948 s - - shape = (1000, 500000) - brainpylib 1.2167487144470215 s - JAX argsort 71.68716263771057 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.argsort(events))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2 = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX argsort {time() - t0} s') - - print() - - -def compare_where(platform='cpu'): - """ - - CPU - --- - - shape = (100, 10000) - brainpylib 0.20480966567993164 s - JAX where 0.7068588733673096 s - - shape = (100, 100000) - brainpylib 2.3373026847839355 s - JAX where 5.862265348434448 s - - shape = (1000, 10000) - brainpylib 2.105764865875244 s - JAX where 5.914586067199707 s - - shape = (10000, 1000) - brainpylib 1.724682331085205 s - JAX where 5.718563795089722 s - - GPU - --- - shape = (100, 100000) - brainpylib 0.15492558479309082 s - JAX where 0.3146538734436035 s - - shape = (100, 1000000) - brainpylib 0.3290700912475586 s - JAX where 1.7064015865325928 s - - shape = (1000, 100000) - brainpylib 0.2895216941833496 s - JAX where 1.6910102367401123 s - - shape = (1000, 500000) - brainpylib 1.173649787902832 s - JAX where 7.868000268936157 s - - """ - - bm.set_platform(platform) - - rng = bm.random.RandomState(123) - bp_event_info = jit(vmap(bm.event.info)) - jax_event_info = jit(vmap(lambda events: jnp.where(events, size=events.shape[0]))) - - if platform == 'cpu': - all_shapes = [ - (100, 10000), - (100, 100000), - (1000, 10000), - (10000, 1000), - ] - else: - all_shapes = [ - (100, 100000), - (100, 1000000), - (1000, 100000), - (1000, 500000), - ] - - for shape in all_shapes: - print(f'shape = {shape}') - - events = rng.random(shape).value < 0.1 - event_ids1, event_num1 = bp_event_info(events) - event_ids1.block_until_ready() - event_ids2, = jax_event_info(events) - event_ids2.block_until_ready() - - t0 = time() - for _ in range(100): - a, b = bp_event_info(events) - r = a.block_until_ready() - print(f'brainpylib {time() - t0} s') - - t0 = time() - for _ in range(100): - a, = jax_event_info(events) - r = a.block_until_ready() - print(f'JAX where {time() - t0} s') - - print() - - -if __name__ == '__main__': - # compare_argsort_and_sum('cpu') - # compare_argsort_and_sum('gpu') - # compare_argsort('cpu') - compare_argsort('gpu') - # compare_where('cpu') - # compare_where('gpu') diff --git a/brainpy/_src/math/event/tests/test_event_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py index e0f38490f..1641c9db9 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv.py @@ -4,11 +4,18 @@ from functools import partial import jax +import pytest from absl.testing import parameterized import brainpy as bp import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) + + seed = 1234 diff --git a/brainpy/_src/math/event/tests/test_event_csrmv_old.py b/brainpy/_src/math/event/tests/test_event_csrmv_old.py index 31a6527a2..fcb25a89c 100644 --- a/brainpy/_src/math/event/tests/test_event_csrmv_old.py +++ b/brainpy/_src/math/event/tests/test_event_csrmv_old.py @@ -4,19 +4,13 @@ from functools import partial import jax -from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -import platform -import pytest pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) - brainpylib_csr_matvec = partial(bm.event.csrmv, method='brainpylib') taichi_csr_matvec = partial(bm.event.csrmv, method='taichi') diff --git a/brainpy/_src/math/event/tests/test_info.py b/brainpy/_src/math/event/tests/test_info.py deleted file mode 100644 index c326b0f76..000000000 --- a/brainpy/_src/math/event/tests/test_info.py +++ /dev/null @@ -1,62 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax.numpy as jnp -import unittest - -import brainpy.math as bm -from jax import vmap - -import pytest - - -class Test_event_info(unittest.TestCase): - def __init__(self, *args, platform='cpu', **kwargs): - super(Test_event_info, self).__init__(*args, **kwargs) - - print() - bm.set_platform(platform) - - def _base_test(self, length): - print(f'{self._base_test.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random(length)) < 0.1 - event_ids, event_num = bm.event.info(events) - self.assertTrue(jnp.allclose(jnp.sum(events, keepdims=True), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap(self, length): - print(f'{self._base_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(bm.event.info)(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def _base_vmap_vmap(self, length): - print(f'{self._base_vmap_vmap.__name__}: length = {length}') - - rng = bm.random.RandomState() - events = bm.as_jax(rng.random((10, length))) < 0.1 - event_ids, event_num = vmap(vmap(bm.event.info))(events) - self.assertTrue(jnp.allclose(jnp.sum(events, axis=-1), event_num)) - - bm.clear_buffer_memory() - - def test(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - def test_vmap_vmap(self): - for length in [1, 3, 8, 10, 100, 200, 500, 1000, 10000, 100000]: - self._base_test(length) - - - diff --git a/brainpy/_src/math/event/tests/test_info_gpu.py b/brainpy/_src/math/event/tests/test_info_gpu.py deleted file mode 100644 index 55bdd15cd..000000000 --- a/brainpy/_src/math/event/tests/test_info_gpu.py +++ /dev/null @@ -1,14 +0,0 @@ -# -*- coding: utf-8 -*- - -import jax -import pytest - -import test_info - -if jax.default_backend() != 'gpu': - pytest.skip("No gpu available.", allow_module_level=True) - - -class Test_event_info_GPU(test_info.Test_event_info): - def __init__(self, *args, **kwargs): - super(Test_event_info_GPU, self).__init__(*args, **kwargs, platform='gpu') diff --git a/brainpy/_src/math/index_tricks.py b/brainpy/_src/math/index_tricks.py deleted file mode 100644 index 6c71b4b06..000000000 --- a/brainpy/_src/math/index_tricks.py +++ /dev/null @@ -1,305 +0,0 @@ -# -*- coding: utf-8 -*- - -import abc - -from jax import core -from .compat_numpy import arange, array, concatenate, expand_dims, linspace, meshgrid, stack, transpose -import numpy as np - -__all__ = ["c_", "index_exp", "mgrid", "ogrid", "r_", "s_"] - - -def _make_1d_grid_from_slice(s: slice, op_name: str): - start = core.concrete_or_error(None, s.start, - f"slice start of jnp.{op_name}") or 0 - stop = core.concrete_or_error(None, s.stop, - f"slice stop of jnp.{op_name}") - step = core.concrete_or_error(None, s.step, - f"slice step of jnp.{op_name}") or 1 - if np.iscomplex(step): - newobj = linspace(start, stop, int(abs(step))) - else: - newobj = arange(start, stop, step) - - return newobj - - -class _IndexGrid(abc.ABC): - """Creates multi-dimensional grids of indices.""" - sparse: bool - op_name: str - - def __getitem__(self, key): - if isinstance(key, slice): - return _make_1d_grid_from_slice(key, op_name=self.op_name) - output = (_make_1d_grid_from_slice(k, op_name=self.op_name) for k in key) - output = meshgrid(*output, indexing='ij', sparse=self.sparse) - return output if self.sparse else stack(output, 0) - - -class _Mgrid(_IndexGrid): - """Return dense multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.mgrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=False``. - - See Also: - jnp.ogrid: open/sparse version of jnp.mgrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> import brainpy.math as bm - >>> bm.mgrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.mgrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create broadcasted grids of indices: - - >>> bm.mgrid[:2, :3] - DeviceArray([[[0, 0, 0], - [1, 1, 1]], - [[0, 1, 2], - [0, 1, 2]]], dtype=int32) - """ - sparse = False - op_name = "mgrid" - - -mgrid = _Mgrid() - - -class _Ogrid(_IndexGrid): - """Return open multi-dimensional "meshgrid". - - LAX-backend implementation of :obj:`numpy.ogrid`. This is a convenience wrapper for - functionality provided by :func:`jax.numpy.meshgrid` with ``sparse=True``. - - See Also: - jnp.mgrid: dense version of jnp.ogrid - - Examples: - Pass ``[start:stop:step]`` to generate values similar to :func:`jax.numpy.arange`: - - >>> bm.ogrid[0:4:1] - DeviceArray([0, 1, 2, 3], dtype=int32) - - Passing an imaginary step generates values similar to :func:`jax.numpy.linspace`: - - >>> bm.ogrid[0:1:4j] - DeviceArray([0. , 0.33333334, 0.6666667 , 1. ], dtype=float32) - - Multiple slices can be used to create sparse grids of indices: - - >>> bm.ogrid[:2, :3] - [DeviceArray([[0], - [1]], dtype=int32), - DeviceArray([[0, 1, 2]], dtype=int32)] - """ - sparse = True - op_name = "ogrid" - - -ogrid = _Ogrid() - - -class _AxisConcat(abc.ABC): - """Concatenates slices, scalars and array-like objects along a given axis.""" - axis: int - ndmin: int - trans1d: int - op_name: str - - def __getitem__(self, key): - if not isinstance(key, tuple): - key = (key,) - - params = [self.axis, self.ndmin, self.trans1d, -1] - - if isinstance(key[0], str): - # split off the directive - directive, *key = key # pytype: disable=bad-unpacking - # check two special cases: matrix directives - if directive == "r": - params[-1] = 0 - elif directive == "c": - params[-1] = 1 - else: - vec = directive.split(",") - k = len(vec) - if k < 4: - vec += params[k:] - else: - # ignore everything after the first three comma-separated ints - vec = vec[:3] + params[-1] - try: - params = list(map(int, vec)) - except ValueError as err: - raise ValueError( - "could not understand directive {!r}".format(directive) - ) from err - - axis, ndmin, trans1d, matrix = params - - output = [] - for item in key: - if isinstance(item, slice): - newobj = _make_1d_grid_from_slice(item, op_name=self.op_name) - elif isinstance(item, str): - raise ValueError("string directive must be placed at the beginning") - else: - newobj = item - - newobj = array(newobj, copy=False, ndmin=ndmin) - - if trans1d != -1 and ndmin - np.ndim(item) > 0: - shape_obj = list(range(ndmin)) - # Calculate number of left shifts, with overflow protection by mod - num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin - shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts]) - - newobj = transpose(newobj, shape_obj) - - output.append(newobj) - - res = concatenate(tuple(output), axis=axis) - - if matrix != -1 and res.ndim == 1: - # insert 2nd dim at axis 0 or 1 - res = expand_dims(res, matrix) - - return res - - def __len__(self): - return 0 - - -class RClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the first axis. - - LAX-backend implementation of :obj:`numpy.r_`. - - See Also: - ``jnp.c_``: Concatenates slices, scalars and array-like objects along the last axis. - - Examples: - Passing slices in the form ``[start:stop:step]`` generates ``jnp.arange`` objects: - - >>> bm.r_[-1:5:1, 0, 0, bm.array([1,2,3])] - DeviceArray([-1, 0, 1, 2, 3, 4, 0, 0, 1, 2, 3], dtype=int32) - - An imaginary value for ``step`` will create a ``jnp.linspace`` object instead, - which includes the right endpoint: - - >>> bm.r_[-1:1:6j, 0, bm.array([1,2,3])] - DeviceArray([-1. , -0.6 , -0.20000002, 0.20000005, - 0.6 , 1. , 0. , 1. , - 2. , 3. ], dtype=float32) - - Use a string directive of the form ``"axis,dims,trans1d"`` as the first argument to - specify concatenation axis, minimum number of dimensions, and the position of the - upgraded array's original dimensions in the resulting array's shape tuple: - - >>> bm.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, 2D output - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - >>> bm.r_['0,2,0', [1,2,3], [4,5,6]] # push last input axis to the front - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Negative values for ``trans1d`` offset the last axis towards the start - of the shape tuple: - - >>> bm.r_['0,2,-2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with an extra row or column axis, respectively: - - >>> bm.r_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3, 4, 5, 6]], dtype=int32) - - >>> bm.r_['c',[1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - For higher-dimensional inputs (``dim >= 2``), both directives ``"r"`` and ``"c"`` - give the same result. - """ - axis = 0 - ndmin = 1 - trans1d = -1 - op_name = "r_" - - -r_ = RClass() - - -class CClass(_AxisConcat): - """Concatenate slices, scalars and array-like objects along the last axis. - - LAX-backend implementation of :obj:`numpy.c_`. - - See Also: - ``jnp.r_``: Concatenates slices, scalars and array-like objects along the first axis. - - Examples: - - >>> a = bm.arange(6).reshape((2,3)) - >>> bm.c_[a,a] - DeviceArray([[0, 1, 2, 0, 1, 2], - [3, 4, 5, 3, 4, 5]], dtype=int32) - - Use a string directive of the form ``"axis:dims:trans1d"`` as the first argument to specify - concatenation axis, minimum number of dimensions, and the position of the upgraded array's - original dimensions in the resulting array's shape tuple: - - >>> bm.c_['0,2', [1,2,3], [4,5,6]] - DeviceArray([[1], - [2], - [3], - [4], - [5], - [6]], dtype=int32) - - >>> bm.c_['0,2,-1', [1,2,3], [4,5,6]] - DeviceArray([[1, 2, 3], - [4, 5, 6]], dtype=int32) - - Use the special directives ``"r"`` or ``"c"`` as the first argument on flat inputs - to create an array with inputs stacked along the last axis: - - >>> jnp.c_['r',[1,2,3], [4,5,6]] - DeviceArray([[1, 4], - [2, 5], - [3, 6]], dtype=int32) - """ - axis = -1 - ndmin = 2 - trans1d = 0 - op_name = "c_" - - -c_ = CClass() - -s_ = np.s_ - -index_exp = np.index_exp diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py index 3671755a9..33ee9f1b5 100644 --- a/brainpy/_src/math/jitconn/_event_matvec.py +++ b/brainpy/_src/math/jitconn/_event_matvec.py @@ -1,23 +1,15 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.jitconn._matvec import (mv_prob_homo_p, - mv_prob_uniform_p, - mv_prob_normal_p, - mv_prob_homo, +from brainpy._src.math.jitconn._matvec import (mv_prob_homo, mv_prob_uniform, - mv_prob_normal, _general_checking, raw_mv_prob_homo, raw_mv_prob_uniform, @@ -27,9 +19,8 @@ _mv_prob_normal_transpose, _reverse) from brainpy._src.math.ndarray import _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError ti = import_taichi() @@ -50,7 +41,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, shape=shape, transpose=transpose, + return event_mv_prob_homo_taichi(events, weight, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) @@ -68,7 +61,9 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, + return event_mv_prob_uniform_taichi(events, w_low, w_high, conn_prob, seed, + shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) @@ -86,651 +81,11 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel) - - -### BRAINPYLIB ### - -def event_mv_prob_homo_brainpylib( - events: jax.Array, - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - return r - - -event_mv_prob_homo_brainpylib.__doc__ = mv_prob_homo.__doc__ - - -def event_mv_prob_uniform_brainpylib( - events: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, + return event_mv_prob_uniform_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_uniform_brainpylib.__doc__ = mv_prob_uniform.__doc__ - - -def event_mv_prob_normal_brainpylib( - events: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - events = as_jax(events) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -event_mv_prob_normal_brainpylib.__doc__ = mv_prob_normal.__doc__ - - -def _event_matvec_prob_homo_abstract( - events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - out = ShapedArray(dtype=weight.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_homo_cpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_homo' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_homo' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_homo_gpu_translation( - c, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_homo_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1], ) - - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_homo_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_homo_v2' + type_name + event_type - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, weight, clen, seed = primals - event_dot, weight_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_homo_p.bind(events, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(weight_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(event_dot) is ad.Zero: - raise ValueError - dr = mv_prob_homo_p.bind(event_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(event_dot) is ad.Zero: - dr = mv_prob_homo_p.bind(events, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - dr = mv_prob_homo_p.bind(event_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, dr - - -def _event_matvec_prob_homo_transpose( - ct, events, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -event_mv_prob_homo_p = Primitive('event_mv_prob_homo') -event_mv_prob_homo_p.multiple_results = True -event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract) -event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation -ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp -ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose -register_general_batching(event_mv_prob_homo_p) - - -def _event_matvec_prob_uniform_abstract( - events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_low.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _event_matvec_prob_uniform_cpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_uniform' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_uniform' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_uniform_gpu_translation( - c, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_uniform_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_uniform_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_uniform_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _event_matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_low, w_high, clen, seed = primals - events_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_uniform_p.bind(events, - w_low, - w_high, - clen, - seed, - shape=shape, - outdim_parallel=outdim_parallel, - transpose=transpose) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(events_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_uniform_transpose( - ct, events, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -event_mv_prob_uniform_p = Primitive('event_mv_prob_uniform') -event_mv_prob_uniform_p.multiple_results = True -event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract) -event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation -register_general_batching(event_mv_prob_uniform_p) -ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp -ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose - - -def _event_matvec_prob_normal_abstract( - events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(events) in [jnp.bool_, jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if events.ndim != 1: - raise ValueError('events should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - assert w_mu.dtype == w_sigma.dtype - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - if transpose: - if events.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({events.shape[0]},) @ mat {shape}.') - else: - if events.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({events.shape[0]},).') - - out = ShapedArray(dtype=w_mu.dtype, shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _get_types(event_shape): - event_type = event_shape.element_type() - if event_type == jnp.bool_: - event_type = b'_bool' - out_dtype = dtypes.canonicalize_dtype(float) - elif event_type == jnp.float32: - event_type = b'_float' - out_dtype = event_shape.element_type() - elif event_type == jnp.float64: - event_type = b'_double' - out_dtype = event_shape.element_type() - else: - raise TypeError - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - return out_dtype, event_type, type_name - - -def _event_matvec_prob_normal_cpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - if outdim_parallel: - fn = b'cpu_event_matvec_prob_normal' + type_name + event_type - else: - fn = b'cpu_event_matvec_atomic_prob_normal' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _event_matvec_prob_normal_gpu_translation( - c, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(event_mv_prob_normal_p.name) - - out_dtype, event_type, type_name = _get_types(c.get_shape(events)) - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - if outdim_parallel: - fn = b'gpu_jit_event_csrmv_prob_normal_v2' + type_name + event_type - else: - fn = b'gpu_jit_event_csrmv_atomic_prob_normal_v2' + type_name + event_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(events, w_mu, w_sigma, clen, seed), - operand_shapes_with_layout=(c.get_shape(events), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - + outdim_parallel=outdim_parallel) -def _event_matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - events, w_mu, w_sigma, clen, seed = primals - events_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = event_mv_prob_normal_p.bind(events, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(events_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _event_matvec_prob_normal_transpose( - ct, events, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(events) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -event_mv_prob_normal_p = Primitive('event_mv_prob_normal') -event_mv_prob_normal_p.multiple_results = True -event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract) -event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation -register_general_batching(event_mv_prob_normal_p) -ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp -ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose - - -### TAICHI ### def event_mv_prob_homo_taichi( events: jax.Array, @@ -790,6 +145,9 @@ def event_mv_prob_homo_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(weight, float): weight = as_jax(weight) weight = jnp.atleast_1d(as_jax(weight)) @@ -799,8 +157,10 @@ def event_mv_prob_homo_taichi( with jax.ensure_compile_time_eval(): seed = np.random.randint(0, int(1e8), 1) seed = jnp.atleast_1d(jnp.asarray(seed, dtype=jnp.uint32)) - return raw_event_mv_prob_homo(events, weight, conn_len, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel)[0] + return raw_event_mv_prob_homo(events, weight, conn_len, seed, + shape=shape, + transpose=transpose, + outdim_parallel=outdim_parallel)[0] def event_mv_prob_uniform_taichi( @@ -864,6 +224,9 @@ def event_mv_prob_uniform_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(w_low, float): w_low = as_jax(w_low) if isinstance(w_high, float): w_high = as_jax(w_high) @@ -940,6 +303,9 @@ def event_mv_prob_normal_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + events = as_jax(events) if isinstance(w_mu, float): w_mu = as_jax(w_mu) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma) @@ -955,1034 +321,1033 @@ def event_mv_prob_normal_taichi( transpose=transpose, outdim_parallel=outdim_parallel)[0] -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col]: + r += weight0 key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] = r + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: - index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison without if else + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + index = i & 31 + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _reverse(shape): - return shape[::-1] - - -# ------------- -# CPU function -# ------------- -# For each non-zero event value, it generates a random key using a -# function lfsr88_key and then uses this key to compute random integers -# and update the out array based on the computed indices and weight. -# -# The function is likely designed to be parallelized. - - -@ti.kernel -def _event_mv_prob_homo_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - out[i_row] += weight0 + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _reverse(shape): + return shape[::-1] + + + # ------------- + # CPU function + # ------------- + # For each non-zero event value, it generates a random key using a + # function lfsr88_key and then uses this key to compute random integers + # and update the out array based on the computed indices and weight. + # + # The function is likely designed to be parallelized. + + @ti.kernel + def _event_mv_prob_homo_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): if events[i_col] != 0.: - r += weight0 - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -# ------------- -# GPU function -# ------------- -# Contrary to the CPU functions, for each column, -# this function will 32 threads (one warp) to make -# the just-in-time random generation parallelized. - - -@ti.kernel -def _event_mv_prob_homo_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + if events[i_col] != 0.: + r += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + # ------------- + # GPU function + # ------------- + # Contrary to the CPU functions, for each column, + # this function will 32 threads (one warp) to make + # the just-in-time random generation parallelized. + + @ti.kernel + def _event_mv_prob_homo_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + out[i_row] += weight0 + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_homo_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 + i_col += inc + while i_col < end_col: + r += weight0 * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + i_col += inc + out[i_row] += r # TODO: warp-level reduction -@ti.kernel -def _event_mv_prob_homo_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += weight0 * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def _event_mv_prob_homo_jvp_events( + evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(evt_dot, weight, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_events( - evt_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(evt_dot, weight, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_mv_prob_homo_jvp_weight( + w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(events, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _event_mv_prob_homo_jvp_weight( - w_dot, events, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(events, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + + + def raw_event_mv_prob_homo( + events: jax.Array, + weight: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_outdim_parallel_bool_p + else: + prim = _event_mv_prob_homo_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_homo_bool_p + else: + prim = _event_mv_prob_homo_p + + return prim(events, + weight, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_homo_jvp_events, + _event_mv_prob_homo_jvp_weight, + None, + None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_bool_cpu, + gpu_kernel=_event_mv_prob_homo_bool_gpu + ) -def _event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.bool_, jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( + cpu_kernel=_event_mv_prob_homo_cpu, + gpu_kernel=_event_mv_prob_homo_gpu + ) -def raw_event_mv_prob_homo( - events: jax.Array, - weight: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, weight) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_outdim_parallel_bool_p - else: - prim = _event_mv_prob_homo_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_homo_bool_p - else: - prim = _event_mv_prob_homo_p - - return prim(events, - weight, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=weight.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_homo_jvp_events, - _event_mv_prob_homo_jvp_weight, - None, - None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_homo_outdim_parallel_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_homo_bool_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_bool_cpu, - gpu_kernel=_event_mv_prob_homo_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_homo_outdim_parallel_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_homo_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_homo_p = _define_event_mv_prob_homo_prim( - cpu_kernel=_event_mv_prob_homo_cpu, - gpu_kernel=_event_mv_prob_homo_gpu -) - - -@ti.kernel -def _event_mv_prob_uniform_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_uniform_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_uniform_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_uniform_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) + i_col += inc + out[i_row] = r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_uniform_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_uniform_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_uniform_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_uniform_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_uniform_jvp_events( - evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_uniform_jvp_w_low( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_uniform_jvp_events( + evt_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(evt_dot, w_low, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_low( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_dot, w_high, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_uniform_jvp_w_high( + w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_uniform( + events: jax.Array, + w_low: jax.Array, # vector with size 1 + w_high: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_outdim_parallel_bool_p + else: + prim = _event_mv_prob_uniform_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_uniform_bool_p + else: + prim = _event_mv_prob_uniform_p + + return prim(events, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_uniform_jvp_events, + _event_mv_prob_uniform_jvp_w_low, + _event_mv_prob_uniform_jvp_w_high, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_bool_cpu, + gpu_kernel=_event_mv_prob_uniform_bool_gpu + ) -def _event_mv_prob_uniform_jvp_w_high( - w_dot, events, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(events, w_low, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu + ) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( + cpu_kernel=_event_mv_prob_uniform_cpu, + gpu_kernel=_event_mv_prob_uniform_gpu + ) -def raw_event_mv_prob_uniform( - events: jax.Array, - w_low: jax.Array, # vector with size 1 - w_high: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_outdim_parallel_bool_p - else: - prim = _event_mv_prob_uniform_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_uniform_bool_p - else: - prim = _event_mv_prob_uniform_p - - return prim(events, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_low.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_uniform_jvp_events, - _event_mv_prob_uniform_jvp_w_low, - _event_mv_prob_uniform_jvp_w_high, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_uniform_bool_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_bool_cpu, - gpu_kernel=_event_mv_prob_uniform_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_uniform_outdim_parallel_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_uniform_p = _define_event_mv_prob_uniform_prim( - cpu_kernel=_event_mv_prob_uniform_cpu, - gpu_kernel=_event_mv_prob_uniform_gpu -) - - -@ti.kernel -def _event_mv_prob_normal_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col]: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + @ti.kernel + def _event_mv_prob_normal_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col]: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col]: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col]: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col]: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_bool_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison without if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_bool_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison without if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -@ti.kernel -def _event_mv_prob_normal_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - if events[i_col] != 0.: - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + @ti.kernel + def _event_mv_prob_normal_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + if events[i_col] != 0.: + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_cpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + if events[i_col] != 0.: + r += row_v key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_cpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _event_mv_prob_normal_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 if events[i_col] != 0.: - r += row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _event_mv_prob_normal_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - if events[i_col] != 0.: + index = i & 31 + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _event_mv_prob_normal_outdim_parallel_gpu( + events: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = events.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 index = i & 31 - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) + i_col = step * index - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: + i_col += inc + while i_col < end_col: key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v + r += row_v * events[i_col] # TODO: speed comparison with if else key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _event_mv_prob_normal_outdim_parallel_gpu( - events: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = events.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - index = i & 31 - i_col = step * index - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += row_v * events[i_col] # TODO: speed comparison with if else - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction - - -def _event_mv_prob_normal_jvp_events( - evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_mu( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) - - -def _event_mv_prob_normal_jvp_w_sigma( - w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _event_mv_prob_normal_jvp_events( + evt_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(evt_dot, w_mu, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_mu( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_dot, w_sigma, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def _event_mv_prob_normal_jvp_w_sigma( + w_dot, events, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(events, w_mu, w_dot, clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + + + def raw_event_mv_prob_normal( + events: jax.Array, + w_mu: jax.Array, # vector with size 1 + w_sigma: jax.Array, # vector with size 1 + conn_len: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_outdim_parallel_bool_p + else: + prim = _event_mv_prob_normal_outdim_parallel_p + else: + if events.dtype == jnp.bool_: + prim = _event_mv_prob_normal_bool_p + else: + prim = _event_mv_prob_normal_p + + return prim(events, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_event_mv_prob_normal_jvp_events, + _event_mv_prob_normal_jvp_w_mu, + _event_mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True, events.dtype = jnp.bool_ + _event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu + ) + # outdim_parallel = False, events.dtype = jnp.bool_ + _event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_bool_cpu, + gpu_kernel=_event_mv_prob_normal_bool_gpu + ) -def raw_event_mv_prob_normal( - events: jax.Array, - w_mu: jax.Array, # vector with size 1 - w_sigma: jax.Array, # vector with size 1 - conn_len: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _event_checking(events, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + # outdim_parallel = True, events.dtype != jnp.bool_ + _event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu + ) - if outdim_parallel: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_outdim_parallel_bool_p - else: - prim = _event_mv_prob_normal_outdim_parallel_p - else: - if events.dtype == jnp.bool_: - prim = _event_mv_prob_normal_bool_p - else: - prim = _event_mv_prob_normal_p - - return prim(events, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=w_mu.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_event_mv_prob_normal_jvp_events, - _event_mv_prob_normal_jvp_w_mu, - _event_mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True, events.dtype = jnp.bool_ -_event_mv_prob_normal_outdim_parallel_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_bool_gpu -) - -# outdim_parallel = False, events.dtype = jnp.bool_ -_event_mv_prob_normal_bool_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_bool_cpu, - gpu_kernel=_event_mv_prob_normal_bool_gpu -) - -# outdim_parallel = True, events.dtype != jnp.bool_ -_event_mv_prob_normal_outdim_parallel_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_event_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False, events.dtype != jnp.bool_ -_event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( - cpu_kernel=_event_mv_prob_normal_cpu, - gpu_kernel=_event_mv_prob_normal_gpu -) + # outdim_parallel = False, events.dtype != jnp.bool_ + _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( + cpu_kernel=_event_mv_prob_normal_cpu, + gpu_kernel=_event_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py index 0caa9c996..84abb9805 100644 --- a/brainpy/_src/math/jitconn/_matvec.py +++ b/brainpy/_src/math/jitconn/_matvec.py @@ -1,22 +1,18 @@ # -*- coding: utf-8 -*- -from functools import partial from typing import Tuple, Optional, Union import jax import numpy as np -from jax import numpy as jnp, dtypes -from jax.core import ShapedArray, Primitive -from jax.interpreters import xla, ad -from jax.lib import xla_client +from jax import numpy as jnp +from jax.interpreters import ad -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_brainpylib_cpu_ops, import_taichi +from brainpy._src.dependency_check import import_taichi from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array, _get_dtype -from brainpy._src.math.op_register import register_general_batching, XLACustomOp -from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) -from brainpy.errors import GPUOperatorNotFound +from brainpy._src.math.op_register import XLACustomOp +from brainpy.errors import PackageMissingError ti = import_taichi() @@ -215,808 +211,11 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, + return mv_prob_uniform_taichi(vector, w_mu, w_sigma, conn_prob, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -### BRAINYPLIB ### - -def mv_prob_homo_brainpylib( - vector: Union[Array, jax.Array], - weight: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a scalar `weight` at each position. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - weight: float - The value of the random matrix. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - weight = jnp.atleast_1d(as_jax(weight)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel, - )[0] - - -def mv_prob_uniform_brainpylib( - vector: jax.Array, - w_low: float, - w_high: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a uniform distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_low: float - Lower boundary of the output interval. - w_high: float - Upper boundary of the output interval. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_low = jnp.atleast_1d(as_jax(w_low)) - w_high = jnp.atleast_1d(as_jax(w_high)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def mv_prob_normal_brainpylib( - vector: jax.Array, - w_mu: float, - w_sigma: float, - conn_prob: float, - seed: Optional[int] = None, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - r"""Perform the :math:`y=M@v` operation, - where :math:`M` is just-in-time randomly generated with a normal distribution for its value. - - This operator support ``jit()``, ``vmap()``, ``grad()`` and ``pmap()`` etc. transformations - on CPU and GPU devices. - - .. warning:: - - This API may change in the future. - - In this operation, :math:`M` is the random matrix with a connection probability - `conn_prob`, and at each connection the value is the same scalar `weight`. - - When ``transpose=True``, we perform an operation of :math:`y=M^T@v`. - - .. note:: - - Note that the just-in-time generated :math:`M` (`transpose=False`) is - different from the generated :math:`M^T` (`transpose=True`). - - If you pursue the same :math:`M` and :math:`M^T` when performing the just-in-time - matrix generation, you should set ``outdim_parallel=True``, with the sacrifice of - the speed compared with ``outdim_parallel=False``. - - Parameters - ---------- - vector: Array, ndarray - The vector. - w_mu: float - Mean (centre) of the distribution. - w_sigma: float - Standard deviation (spread or “width”) of the distribution. Must be non-negative. - conn_prob: float - The connection probability. - shape: tuple of int - The matrix shape. - seed: int - The random number generation seed. - transpose: bool - Transpose the random matrix or not. - outdim_parallel: bool - Perform the parallel random generations along the out dimension or not. - It can be used to set the just-in-time generated :math:M^T: is the same - as the just-in-time generated :math:`M` when ``transpose=True``. - - Returns - ------- - out: Array, ndarray - The output of :math:`y = M @ v`. - """ - vector = as_jax(vector) - w_mu = jnp.atleast_1d(as_jax(w_mu)) - w_sigma = jnp.atleast_1d(as_jax(w_sigma)) - conn_prob = jnp.atleast_1d(as_jax(conn_prob)) - clen = jnp.asarray(jnp.ceil(1 / conn_prob) * 2 - 1, dtype=jnp.int32) - with jax.ensure_compile_time_eval(): - if seed is None: - seed = int(np.random.randint(0, int(1e8))) - seed = jnp.atleast_1d(as_jax(seed, dtype=jnp.int32)) - return mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel)[0] - - -def _matvec_prob_homo_abstract( - vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - assert _get_dtype(weight) in [jnp.float32, jnp.float64], '"weight" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - if transpose: - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec ({vector.shape[0]},) @ mat {shape}.') - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_homo_cpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - out_type = b'_float' - elif out_dtype == jnp.float64: - out_type = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_homo' + out_type - else: - fn = b'cpu_matvec_atomic_prob_homo' + out_type - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - weight, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_homo_gpu_translation( - c, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_homo_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_homo_v2' + type_name - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, weight, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(weight), - c.get_shape(clen), - c.get_shape(seed)), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_homo_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, weight, clen, seed = primals - vector_dot, weight_dot, clen_dot, seed_dot = tangents - r = mv_prob_homo_p.bind(vector, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - if type(weight_dot) is ad.Zero: - if type(vector_dot) is ad.Zero: - raise ValueError - r_dot = mv_prob_homo_p.bind(vector_dot, - weight, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - elif type(vector_dot) is ad.Zero: - r_dot = mv_prob_homo_p.bind(vector, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - else: - r_dot = mv_prob_homo_p.bind(vector_dot, - weight_dot, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - return r, r_dot - - -def _matvec_prob_homo_transpose( - ct, vector, weight, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(weight) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - assert type(vector) is ad.UndefinedPrimal - r = mv_prob_homo_p.bind(ct[0], - weight, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, weight, clen, seed - - -mv_prob_homo_p = Primitive('matvec_prob_homo') -mv_prob_homo_p.multiple_results = True -mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract) -mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p)) -# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation -register_general_batching(mv_prob_homo_p) -ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp -ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose - - -def _matvec_prob_uniform_abstract( - vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_low_dtype = _get_dtype(w_low) - _w_high_dtype = _get_dtype(w_low) - assert _w_low_dtype == _w_high_dtype, '"w_low" and "w_high" must be same typed.' - assert _w_low_dtype in [jnp.float32, jnp.float64], '"w_low" must be float valued.' - assert _w_high_dtype in [jnp.float32, jnp.float64], '"w_high" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if w_low.ndim != 1: - raise ValueError('w_low must be a 1D scalar.') - if w_high.ndim != 1: - raise ValueError('w_high must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen must be a 1D scalar.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - assert w_low.dtype == w_high.dtype == vector.dtype - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_uniform_cpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_uniform' + type_name - else: - fn = b'cpu_matvec_atomic_prob_uniform' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_low, - w_high, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_uniform_gpu_translation( - c, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_uniform_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_uniform_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, w_low, w_high, clen, seed), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_low), - c.get_shape(w_high), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_uniform_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_low, w_high, clen, seed = primals - vector_dot, w_low_dot, w_high_dot, clen_dot, seed_dot = tangents - r = mv_prob_uniform_p.bind(vector, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_low_dot) is ad.Zero - assert type(w_high_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_uniform_p.bind(vector_dot, - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_low) is not ad.UndefinedPrimal - assert type(w_high) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_uniform_p.bind(ct[0], - w_low, - w_high, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_low, w_high, clen, seed - - -mv_prob_uniform_p = Primitive('matvec_prob_uniform') -mv_prob_uniform_p.multiple_results = True -mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract) -mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p)) -# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation -register_general_batching(mv_prob_uniform_p) -ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp -ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose - - -def _matvec_prob_normal_abstract( - vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert _get_dtype(vector) in [jnp.float32, jnp.float64] - _w_mu_dtype = _get_dtype(w_mu) - _w_sigma_dtype = _get_dtype(w_sigma) - assert _w_mu_dtype == _w_sigma_dtype, '"w_mu" and "w_sigma" must be same typed.' - assert _w_mu_dtype in [jnp.float32, jnp.float64], '"w_mu" must be float valued.' - assert _w_sigma_dtype in [jnp.float32, jnp.float64], '"w_sigma" must be float valued.' - assert _get_dtype(clen) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int32, jnp.int64, jnp.uint32, jnp.uint64] - - if w_mu.ndim != 1: - raise ValueError('w_mu should be a 1D scalar.') - if w_sigma.ndim != 1: - raise ValueError('w_sigma should be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('clen should be a 1D scalar.') - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be a boolean value.') - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be a boolean value.') - - out = ShapedArray(dtype=dtypes.canonicalize_dtype(float), - shape=(shape[1] if transpose else shape[0],)) - return [out] - - -def _matvec_prob_normal_cpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - import_brainpylib_cpu_ops() - n_row, n_col = (shape[1], shape[0]) if transpose else shape - - vec_shape = c.get_shape(vector) - out_dtype = vec_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError - - if outdim_parallel: - fn = b'cpu_matvec_prob_normal' + type_name - else: - fn = b'cpu_matvec_atomic_prob_normal' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed, - xla_client.ops.ConstantLiteral(c, n_row), - xla_client.ops.ConstantLiteral(c, n_col)), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ()), - xla_client.Shape.array_shape(np.dtype(np.uint32), (), ())), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - ) - - -def _matvec_prob_normal_gpu_translation( - c, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(mv_prob_homo_p.name) - - event_shape = c.get_shape(vector) - out_dtype = event_shape.element_type() - - if out_dtype == jnp.float32: - type_name = b'_float' - elif out_dtype == jnp.float64: - type_name = b'_double' - else: - raise TypeError(f'Only support float or double, while got {out_dtype}') - opaque = gpu_ops.build_double_size_descriptor(shape[1] if transpose else shape[0], - shape[0] if transpose else shape[1]) - - if outdim_parallel: - fn = b'gpu_jit_csrmv_prob_normal_v2' + type_name - else: - fn = b'gpu_jit_csrmv_atomic_prob_normal_v2' + type_name - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(vector, - w_mu, - w_sigma, - clen, - seed,), - operand_shapes_with_layout=(c.get_shape(vector), - c.get_shape(w_mu), - c.get_shape(w_sigma), - c.get_shape(clen), - c.get_shape(seed),), - shape_with_layout=xla_client.Shape.tuple_shape( - ( - xla_client.Shape.array_shape(out_dtype, (shape[1] if transpose else shape[0],), (0,)), - ) - ), - opaque=opaque, - ) - - -def _matvec_prob_normal_jvp( - primals, tangents, *, shape, transpose, outdim_parallel -): - vector, w_mu, w_sigma, clen, seed = primals - vector_dot, w_mu_dot, w_sigma_dot, clen_dot, seed_dot = tangents - r = mv_prob_normal_p.bind(vector, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - assert type(w_mu_dot) is ad.Zero - assert type(w_sigma_dot) is ad.Zero - assert type(clen_dot) is ad.Zero - assert type(seed_dot) is ad.Zero - r_dot = mv_prob_normal_p.bind(vector_dot, - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - return r, r_dot - - -def _matvec_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, shape, transpose, outdim_parallel -): - assert type(vector) is ad.UndefinedPrimal - assert type(w_mu) is not ad.UndefinedPrimal - assert type(w_sigma) is not ad.UndefinedPrimal - assert type(clen) is not ad.UndefinedPrimal - assert type(seed) is not ad.UndefinedPrimal - - r = mv_prob_normal_p.bind(ct[0], - w_mu, - w_sigma, - clen, - seed, - shape=shape, - transpose=not transpose, - outdim_parallel=not outdim_parallel)[0] - return r, w_mu, w_sigma, clen, seed - - -mv_prob_normal_p = Primitive('matvec_prob_normal') -mv_prob_normal_p.multiple_results = True -mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract) -mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p)) -# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation -# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation -register_general_batching(mv_prob_normal_p) -ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp -ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose - - -### TAICHI ### def mv_prob_homo_taichi( vector: Union[Array, jax.Array], weight: float, @@ -1081,6 +280,9 @@ def mv_prob_homo_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(weight, float): weight = as_jax(weight, dtype=vector.dtype) @@ -1157,6 +359,9 @@ def mv_prob_uniform_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(w_low, float): w_low = as_jax(w_low, dtype=vector.dtype) if isinstance(w_high, float): w_high = as_jax(w_high, dtype=vector.dtype) @@ -1233,6 +438,9 @@ def mv_prob_normal_taichi( out: Array, ndarray The output of :math:`y = M @ v`. """ + if ti is None: + raise PackageMissingError(name='taichi==1.7.0', purpose='customized operators') + vector = as_jax(vector) if isinstance(w_mu, float): w_mu = as_jax(w_mu, dtype=vector.dtype) if isinstance(w_sigma, float): w_sigma = as_jax(w_sigma, dtype=vector.dtype) @@ -1252,654 +460,657 @@ def _reverse(shape): return shape[::-1] -@ti.kernel -def _mv_prob_homo_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - v = vector[i_col] * weight0 - while i_row < num_row: - out[i_row] += v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - r += vector[i_col] - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r * weight0 - - -@ti.kernel -def _mv_prob_homo_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - out[i_row] += weight0 * col_v +if ti is not None: + from brainpy._src.math.tifunc import (lfsr88_key, lfsr88_random_integers, lfsr88_uniform, lfsr88_normal) + + @ti.kernel + def _mv_prob_homo_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + v = vector[i_col] * weight0 + while i_row < num_row: + out[i_row] += v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r * weight0 + + + @ti.kernel + def _mv_prob_homo_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_homo_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - weight0 = weight[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - r += vector[i_col] + while i_row < end: + out[i_row] += weight0 * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_homo_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + weight0 = weight[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += weight0 * r # TODO: warp-level reduction + while i_col < end_col: + r += vector[i_col] + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += weight0 * r # TODO: warp-level reduction -def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_homo_jvp_vector(v_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(v_dot, weight, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_homo_jvp_weight(w_dot, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_homo(vector, w_dot, clen, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_homo_transpose( - ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), weight, clen, seed + def _mv_prob_homo_transpose( + ct, vector, weight, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), weight, clen, seed + else: + dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, weight, clen, seed + elif ad.is_undefined_primal(weight): + if type(ct) is ad.Zero: + return vector, ad.Zero(weight), clen, seed + else: + row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, + shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + dw = jnp.sum(row * vector, keepdims=True) + return vector, dw, clen, seed else: - dv = raw_mv_prob_homo(ct[0], weight, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, weight, clen, seed - elif ad.is_undefined_primal(weight): - if type(ct) is ad.Zero: - return vector, ad.Zero(weight), clen, seed + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + if vector.ndim != 1: + raise ValueError('vector should be a 1D vector.') + if len(shape) != 2: + raise ValueError('shape should be a length-2 tuple.') + if seed.ndim != 1: + raise ValueError('seed must be a 1D scalar.') + if clen.ndim != 1: + raise ValueError('conn_prob must be a 1D scalar.') + + assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] + + for weight in weights: + if weight.ndim != 1: + raise ValueError('weight must be a 1D scalar.') + assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' + + if not isinstance(outdim_parallel, bool): + raise ValueError('outdim_parallel must be boolean value.') + if not isinstance(transpose, bool): + raise ValueError('transpose must be boolean value.') + + if transpose: + out_shape = (shape[1],) + if vector.shape[0] != shape[0]: + raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') + shape = _reverse(shape) else: - row = raw_mv_prob_homo(ct[0], jnp.ones(1, dtype=ct[0].dtype), clen, seed, - shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] - dw = jnp.sum(row * vector, keepdims=True) - return vector, dw, clen, seed - else: - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - if vector.ndim != 1: - raise ValueError('vector should be a 1D vector.') - if len(shape) != 2: - raise ValueError('shape should be a length-2 tuple.') - if seed.ndim != 1: - raise ValueError('seed must be a 1D scalar.') - if clen.ndim != 1: - raise ValueError('conn_prob must be a 1D scalar.') - - assert _get_dtype(clen) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - assert _get_dtype(seed) in [jnp.int16, jnp.int32, jnp.int64, jnp.uint16, jnp.uint32, jnp.uint64] - - for weight in weights: - if weight.ndim != 1: - raise ValueError('weight must be a 1D scalar.') - assert _get_dtype(weight) in [jnp.float16, jnp.float32, jnp.float64], '"weight" must be float valued.' - - if not isinstance(outdim_parallel, bool): - raise ValueError('outdim_parallel must be boolean value.') - if not isinstance(transpose, bool): - raise ValueError('transpose must be boolean value.') - - if transpose: - out_shape = (shape[1],) - if vector.shape[0] != shape[0]: - raise ValueError(f'Shape mismatch, vec {vector.shape} @ mat {shape}.') - shape = _reverse(shape) - else: - if vector.shape[0] != shape[1]: - raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') - out_shape = (shape[0],) - - return shape, out_shape - - -def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): - assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] - return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) - - -def raw_mv_prob_homo( - vector: jax.Array, - weight: jax.Array, # vector with size 1 - clen: jax.Array, # vector with size 1 - seed: jax.Array, # vector with size 1 - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) - - if outdim_parallel: - prim = _mv_prob_homo_outdim_parallel_p - else: - prim = _mv_prob_homo_p - - return prim(vector, - weight, - clen, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) - prim.def_transpose_rule(_mv_prob_homo_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, - gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) - -# outdim_parallel = False -_mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, - gpu_kernel=_mv_prob_homo_gpu) - - -@ti.kernel -def _mv_prob_uniform_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + if vector.shape[0] != shape[1]: + raise ValueError(f'Shape mismatch, mat {shape} @ vec ({vector.shape[0]},).') + out_shape = (shape[0],) + return shape, out_shape -@ti.kernel -def _mv_prob_uniform_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_uniform_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - out[i_row] += row_v * col_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc + def _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights): + assert _get_dtype(vector) in [jnp.float16, jnp.float32, jnp.float64] + return _general_checking(vector, clen, seed, shape, outdim_parallel, transpose, *weights) -@ti.kernel -def _mv_prob_uniform_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_min: ti.types.ndarray(ndim=1), - w_max: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_min0 = w_min[0] - w_max0 = w_max[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_uniform(key, w_min0, w_max0) - r += vector[i_col] * row_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] += r # TODO: warp-level reduction + def raw_mv_prob_homo( + vector: jax.Array, + weight: jax.Array, # vector with size 1 + clen: jax.Array, # vector with size 1 + seed: jax.Array, # vector with size 1 + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, clen, seed, shape, outdim_parallel, transpose, weight) -def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, - outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - - -def _mv_prob_uniform_transpose( - ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_low, w_high, clen, seed + if outdim_parallel: + prim = _mv_prob_homo_outdim_parallel_p else: - dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_low, w_high, clen, seed - else: - assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' - assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - - -def raw_mv_prob_uniform( - vector: jax.Array, - w_low: jax.Array, - w_high: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) - - if outdim_parallel: - prim = _mv_prob_uniform_outdim_parallel_p - else: - prim = _mv_prob_uniform_p - - return prim(vector, - w_low, - w_high, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_uniform_jvp_vector, - _mv_prob_uniform_jvp_wlow, - _mv_prob_uniform_jvp_whigh, - None, - None) - prim.def_transpose_rule(_mv_prob_uniform_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, - gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_uniform_p = _define_mv_prob_uniform_prim( - cpu_kernel=_mv_prob_uniform_cpu, - gpu_kernel=_mv_prob_uniform_gpu -) - - -@ti.kernel -def _mv_prob_normal_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_col in range(num_col): - col_v = vector[i_col] - key = lfsr88_key(seed0 + i_col) - key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) - while i_row < num_row: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += col_v * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_cpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - - for i_row in range(num_row): - r = 0. - key = lfsr88_key(seed0 + i_row) - key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) - while i_col < num_col: - key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * raw_v - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - out[i_row] = r - - -@ti.kernel -def _mv_prob_normal_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_col * 32): - i_col = i >> 5 - index = i & 31 - col_v = vector[i_col] - i_row = step * index - 1 - end = ti.min(i_row + step, num_row) - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_row += inc - while i_row < end: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - out[i_row] += row_v * col_v + prim = _mv_prob_homo_p + + return prim(vector, + weight, + clen, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_homo_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_homo_jvp_vector, _mv_prob_homo_jvp_weight, None, None) + prim.def_transpose_rule(_mv_prob_homo_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_homo_outdim_parallel_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_outdim_parallel_cpu, + gpu_kernel=_mv_prob_homo_outdim_parallel_gpu) + + # outdim_parallel = False + _mv_prob_homo_p = _define_mv_prob_homo_prim(cpu_kernel=_mv_prob_homo_cpu, + gpu_kernel=_mv_prob_homo_gpu) + + + @ti.kernel + def _mv_prob_uniform_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_uniform_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_row += inc - - -@ti.kernel -def _mv_prob_normal_outdim_parallel_gpu( - vector: ti.types.ndarray(ndim=1), - w_mu: ti.types.ndarray(ndim=1), - w_sigma: ti.types.ndarray(ndim=1), - clen: ti.types.ndarray(ndim=1), - seed: ti.types.ndarray(ndim=1), - out: ti.types.ndarray(ndim=1) -): - num_row = out.shape[0] - num_col = vector.shape[0] - w_mu0 = w_mu[0] - w_sigma0 = w_sigma[0] - clen0 = clen[0] - seed0 = seed[0] - step = ti.u32(ti.max((num_row + 1) >> 5, 1)) - - for i in range(num_row * 32): - i_row = i >> 5 - i_thread = i & 31 - i_col = step * i_thread - 1 - end_col = ti.min(i_col + step, num_col) - r = 0. - key = lfsr88_key(seed0 + i) - key, inc = lfsr88_random_integers(key, 1, clen0) - i_col += inc - while i_col < end_col: - key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) - r += vector[i_col] * row_v + while i_row < end: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_uniform_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_min: ti.types.ndarray(ndim=1), + w_max: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_min0 = w_min[0] + w_max0 = w_max[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) key, inc = lfsr88_random_integers(key, 1, clen0) i_col += inc - out[i_row] += r # TODO: warp-level reduction + while i_col < end_col: + key, row_v = lfsr88_uniform(key, w_min0, w_max0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction -def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + def _mv_prob_uniform_jvp_vector(v_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(v_dot, w_low, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) -def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) - + def _mv_prob_uniform_jvp_wlow(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_dot, w_high, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_uniform_jvp_whigh(w_dot, vector, w_low, w_high, clen, seed, *, + outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_uniform(vector, w_low, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_uniform_transpose( + ct, vector, w_low, w_high, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_low, w_high, clen, seed + else: + dv = raw_mv_prob_uniform(ct[0], w_low, w_high, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_low, w_high, clen, seed + else: + assert type(w_low) is not ad.UndefinedPrimal, 'Cannot differentiate through w_low.' + assert type(w_high) is not ad.UndefinedPrimal, 'Cannot differentiate through w_high.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def raw_mv_prob_uniform( + vector: jax.Array, + w_low: jax.Array, + w_high: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_low, w_high) + + if outdim_parallel: + prim = _mv_prob_uniform_outdim_parallel_p + else: + prim = _mv_prob_uniform_p + + return prim(vector, + w_low, + w_high, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_uniform_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_uniform_jvp_vector, + _mv_prob_uniform_jvp_wlow, + _mv_prob_uniform_jvp_whigh, + None, + None) + prim.def_transpose_rule(_mv_prob_uniform_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_uniform_outdim_parallel_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_outdim_parallel_cpu, + gpu_kernel=_mv_prob_uniform_outdim_parallel_gpu + ) -def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): - shape = _reverse(shape) if transpose else shape - return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, - transpose=transpose, outdim_parallel=outdim_parallel) + # outdim_parallel = False + _mv_prob_uniform_p = _define_mv_prob_uniform_prim( + cpu_kernel=_mv_prob_uniform_cpu, + gpu_kernel=_mv_prob_uniform_gpu + ) -def _mv_prob_normal_transpose( - ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel -): - shape = _reverse(shape) if transpose else shape - if ad.is_undefined_primal(vector): - if type(ct) is ad.Zero: - return ad.Zero(vector), w_mu, w_sigma, clen, seed + @ti.kernel + def _mv_prob_normal_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_col in range(num_col): + col_v = vector[i_col] + key = lfsr88_key(seed0 + i_col) + key, i_row = lfsr88_random_integers(key, 0, clen0 - 1) + while i_row < num_row: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += col_v * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_cpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + + for i_row in range(num_row): + r = 0. + key = lfsr88_key(seed0 + i_row) + key, i_col = lfsr88_random_integers(key, 0, clen0 - 1) + while i_col < num_col: + key, raw_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * raw_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] = r + + + @ti.kernel + def _mv_prob_normal_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.uint32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_col * 32): + i_col = i >> 5 + index = i & 31 + col_v = vector[i_col] + i_row = step * index - 1 + end = ti.min(i_row + step, num_row) + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + while i_row < end: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + out[i_row] += row_v * col_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_row += inc + + + @ti.kernel + def _mv_prob_normal_outdim_parallel_gpu( + vector: ti.types.ndarray(ndim=1), + w_mu: ti.types.ndarray(ndim=1), + w_sigma: ti.types.ndarray(ndim=1), + clen: ti.types.ndarray(ndim=1), + seed: ti.types.ndarray(ndim=1), + out: ti.types.ndarray(ndim=1) + ): + num_row = out.shape[0] + num_col = vector.shape[0] + w_mu0 = w_mu[0] + w_sigma0 = w_sigma[0] + clen0 = clen[0] + seed0 = seed[0] + step = ti.u32(ti.max((num_row + 1) >> 5, 1)) + + for i in range(num_row * 32): + i_row = i >> 5 + i_thread = i & 31 + i_col = step * i_thread - 1 + end_col = ti.min(i_col + step, num_col) + r = 0. + key = lfsr88_key(seed0 + i) + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + while i_col < end_col: + key, row_v = lfsr88_normal(key, w_mu0, w_sigma0) + r += vector[i_col] * row_v + key, inc = lfsr88_random_integers(key, 1, clen0) + i_col += inc + out[i_row] += r # TODO: warp-level reduction + + + def _mv_prob_normal_jvp_vector(v_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(v_dot, w_mu, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_mu(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_dot, w_sigma, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_jvp_w_sigma(w_dot, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel): + shape = _reverse(shape) if transpose else shape + return raw_mv_prob_normal(vector, w_mu, w_dot, clen, seed, shape=shape, + transpose=transpose, outdim_parallel=outdim_parallel) + + + def _mv_prob_normal_transpose( + ct, vector, w_mu, w_sigma, clen, seed, *, outs, shape, transpose, outdim_parallel + ): + shape = _reverse(shape) if transpose else shape + if ad.is_undefined_primal(vector): + if type(ct) is ad.Zero: + return ad.Zero(vector), w_mu, w_sigma, clen, seed + else: + dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, + transpose=not transpose, outdim_parallel=not outdim_parallel)[0] + return dv, w_mu, w_sigma, clen, seed else: - dv = raw_mv_prob_normal(ct[0], w_mu, w_sigma, clen, seed, shape=shape, - transpose=not transpose, outdim_parallel=not outdim_parallel)[0] - return dv, w_mu, w_sigma, clen, seed - else: - assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' - assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' - assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' - assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' - + assert type(w_mu) is not ad.UndefinedPrimal, 'Cannot differentiate through w_mu.' + assert type(w_sigma) is not ad.UndefinedPrimal, 'Cannot differentiate through w_sigma.' + assert type(clen) is not ad.UndefinedPrimal, 'Cannot differentiate through clen.' + assert type(seed) is not ad.UndefinedPrimal, 'Cannot differentiate through seed.' + + + def raw_mv_prob_normal( + vector: jax.Array, + w_mu: jax.Array, + w_sigma: jax.Array, + conn_len: jax.Array, + seed: jax.Array, + *, + shape: Tuple[int, int], + transpose: bool = False, + outdim_parallel: bool = True, + ) -> jax.Array: + mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) + + if outdim_parallel: + prim = _mv_prob_normal_outdim_parallel_p + else: + prim = _mv_prob_normal_p + + return prim(vector, + w_mu, + w_sigma, + conn_len, + seed, + outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], + shape=mat_shape, + transpose=transpose, + outdim_parallel=outdim_parallel) + + + def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): + prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) + prim.defjvp(_mv_prob_normal_jvp_vector, + _mv_prob_normal_jvp_w_mu, + _mv_prob_normal_jvp_w_sigma, + None, + None) + prim.def_transpose_rule(_mv_prob_normal_transpose) + return prim + + + # outdim_parallel = True + _mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, + gpu_kernel=_mv_prob_normal_outdim_parallel_gpu + ) -def raw_mv_prob_normal( - vector: jax.Array, - w_mu: jax.Array, - w_sigma: jax.Array, - conn_len: jax.Array, - seed: jax.Array, - *, - shape: Tuple[int, int], - transpose: bool = False, - outdim_parallel: bool = True, -) -> jax.Array: - mat_shape, out_shape = _non_event_checking(vector, conn_len, seed, shape, outdim_parallel, transpose, w_mu, w_sigma) - - if outdim_parallel: - prim = _mv_prob_normal_outdim_parallel_p - else: - prim = _mv_prob_normal_p - - return prim(vector, - w_mu, - w_sigma, - conn_len, - seed, - outs=[jax.ShapeDtypeStruct(shape=out_shape, dtype=vector.dtype)], - shape=mat_shape, - transpose=transpose, - outdim_parallel=outdim_parallel) - - -def _define_mv_prob_normal_prim(cpu_kernel, gpu_kernel): - prim = XLACustomOp(cpu_kernel=cpu_kernel, gpu_kernel=gpu_kernel) - prim.defjvp(_mv_prob_normal_jvp_vector, - _mv_prob_normal_jvp_w_mu, - _mv_prob_normal_jvp_w_sigma, - None, - None) - prim.def_transpose_rule(_mv_prob_normal_transpose) - return prim - - -# outdim_parallel = True -_mv_prob_normal_outdim_parallel_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_outdim_parallel_cpu, - gpu_kernel=_mv_prob_normal_outdim_parallel_gpu -) - -# outdim_parallel = False -_mv_prob_normal_p = _define_mv_prob_normal_prim( - cpu_kernel=_mv_prob_normal_cpu, - gpu_kernel=_mv_prob_normal_gpu -) + # outdim_parallel = False + _mv_prob_normal_p = _define_mv_prob_normal_prim( + cpu_kernel=_mv_prob_normal_cpu, + gpu_kernel=_mv_prob_normal_gpu + ) diff --git a/brainpy/_src/math/jitconn/tests/test_event_matvec.py b/brainpy/_src/math/jitconn/tests/test_event_matvec.py index b10d55d21..034885ae9 100644 --- a/brainpy/_src/math/jitconn/tests/test_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_event_matvec.py @@ -4,8 +4,14 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) + shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/jitconn/tests/test_matvec.py b/brainpy/_src/math/jitconn/tests/test_matvec.py index 2e6e406cf..caee4efbe 100644 --- a/brainpy/_src/math/jitconn/tests/test_matvec.py +++ b/brainpy/_src/math/jitconn/tests/test_matvec.py @@ -4,8 +4,13 @@ import jax import jax.numpy as jnp from absl.testing import parameterized +import pytest import brainpy.math as bm +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) shapes = [(100, 200), (10, 1000), (2, 1000), (1000, 10), (1000, 2)] shapes = [(100, 200), (2, 1000), (1000, 2)] diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fb76aed24..fd7a289ed 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -10,7 +10,6 @@ from .utils import _shape_to_layout - __all__ = [ 'register_numba_xla_cpu_translation_rule', 'register_numba_mlir_cpu_translation_rule', diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index d45f2c80b..6c13ac19a 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,8 +1,8 @@ -from ._coo_mv import * +# from ._coo_mv import * +# from ._bsr_mv import * from ._csr_mv import * from ._utils import * -from ._bsr_mv import * from ._bsr_mm import * from ._jax_prim import * diff --git a/brainpy/_src/math/sparse/tests/test_csrmv.py b/brainpy/_src/math/sparse/tests/test_csrmv.py index 2c75f0901..418a52d35 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv.py @@ -5,10 +5,14 @@ import jax from absl.testing import parameterized +import pytest import brainpy as bp import brainpy.math as bm -# bm.set_platform('gpu') +from brainpy._src.dependency_check import import_taichi + +if import_taichi() is None: + pytest.skip('no taichi', allow_module_level=True) seed = 1234 diff --git a/brainpy/_src/math/sparse/tests/test_csrmv_old.py b/brainpy/_src/math/sparse/tests/test_csrmv_old.py index b73217496..23a3de93a 100644 --- a/brainpy/_src/math/sparse/tests/test_csrmv_old.py +++ b/brainpy/_src/math/sparse/tests/test_csrmv_old.py @@ -4,16 +4,12 @@ import jax import pytest -from absl.testing import parameterized -import platform + import brainpy as bp import brainpy.math as bm pytest.skip('Old implementation.', allow_module_level=True) -is_manual_test = False -# if platform.system() == 'Windows' and not is_manual_test: -# pytest.skip('brainpy.math package may need manual tests.', allow_module_level=True) cusparse_csr_matvec = partial(bm.sparse.csrmv, method='cusparse') scalar_csr_matvec = partial(bm.sparse.csrmv, method='scalar') diff --git a/brainpy/_src/math/tifunc.py b/brainpy/_src/math/tifunc.py index a9ee39f4a..928cb345a 100644 --- a/brainpy/_src/math/tifunc.py +++ b/brainpy/_src/math/tifunc.py @@ -3,362 +3,368 @@ ti = import_taichi() -__all__ = [ - # taichi function for other utilities - 'warp_reduce_sum', +if ti is not None: - # taichi functions for random number generator with LFSR88 algorithm - 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', - 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', + __all__ = [ + # taichi function for other utilities + 'warp_reduce_sum', - # taichi functions for random number generator with LFSR113 algorithm - 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', - 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', -] + # taichi functions for random number generator with LFSR88 algorithm + 'lfsr88_key', 'lfsr88_next_key', 'lfsr88_normal', 'lfsr88_randn', + 'lfsr88_random_integers', 'lfsr88_randint', 'lfsr88_uniform', 'lfsr88_rand', + # taichi functions for random number generator with LFSR113 algorithm + 'lfsr113_key', 'lfsr113_next_key', 'lfsr113_normal', 'lfsr113_randn', + 'lfsr113_random_integers', 'lfsr113_randint', 'lfsr113_uniform', 'lfsr113_rand', + ] -@ti.func -def _lcg_rand(state: ti.types.ndarray(ndim=1)): - # LCG constants - state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) - return state[0] + @ti.func + def _lcg_rand(state: ti.types.ndarray(ndim=1)): + # LCG constants + state[0] = ti.u32(1664525) * state[0] + ti.u32(1013904223) + return state[0] -@ti.func -def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): - """ - Generate a random number using the Taichi LCG algorithm. - Parameters: - seed (ti.types.ndarray): The seed value for the random number generator. + @ti.func + def taichi_lcg_rand(seed: ti.types.ndarray(ndim=1)): + """ + Generate a random number using the Taichi LCG algorithm. - Returns: - float: A random number between 0 and 1. - """ + Parameters: + seed (ti.types.ndarray): The seed value for the random number generator. - return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) + Returns: + float: A random number between 0 and 1. + """ + return float(_lcg_rand(seed)) / ti.u32(2 ** 32 - 1) -############################################# -# Random Number Generator: LFSR88 algorithm # -############################################# + ############################################# + # Random Number Generator: LFSR88 algorithm # + ############################################# -@ti.func -def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. + @ti.func + def lfsr88_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c + This key is used in LFSR88 based random number generator functions, like ``lfsr88_rand()``. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3 MUST be larger than - 1, 7, and 15 respectively. - */ + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr88.c - Args: - seed: int. The seed value for the random number generator. + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3 MUST be larger than + 1, 7, and 15 respectively. + */ - Returns: - ti.math.uvec4: The random key for the LFSR88 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR88 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(0)) -@ti.func -def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Args: - key: The state value for the random number generator. + @ti.func + def lfsr88_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR88 algorithm (Combined LFSR random number generator by L'Ecuyer). - Returns: - ti.math.uvec4: The next random key. - """ - b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) - s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b - b = ((key[1] << 2) ^ key[1]) >> 25 - s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b - b = ((key[2] << 3) ^ key[2]) >> 11 - s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b - return ti.math.uvec4(s1, s2, s3, b) + Args: + key: The state value for the random number generator. + Returns: + ti.math.uvec4: The next random key. + """ + b = ti.u32(((key[0] << 13) ^ key[0]) >> 19) + s1 = ((key[0] & ti.u32(4294967294)) << 12) ^ b + b = ((key[1] << 2) ^ key[1]) >> 25 + s2 = ((key[1] & ti.u32(4294967288)) << 4) ^ b + b = ((key[2] << 3) ^ key[2]) >> 11 + s3 = ((key[2] & ti.u32(4294967280)) << 17) ^ b + return ti.math.uvec4(s1, s2, s3, b) -@ti.func -def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + @ti.func + def lfsr88_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR88 algorithm. - key, r = lfsr88_randn(key, epsilon) - return key, mu + sigma * r + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ + key, r = lfsr88_randn(key, epsilon) + return key, mu + sigma * r -@ti.func -def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with the standard normal distribution using the LFSR88 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr88_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with the standard normal distribution using the LFSR88 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - """ + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - key, u1 = lfsr88_rand(key) - key, u2 = lfsr88_rand(key) + """ - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + key, u1 = lfsr88_rand(key) + key, u2 = lfsr88_rand(key) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - return key, z2 + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def lfsr88_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR88 algorithm. + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2]) % (high + 1 - low) + low, defaults.ti_int) -@ti.func -def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): - key = lfsr88_next_key(key) - return key, dtype(key[0] ^ key[1] ^ key[2]) + @ti.func + def lfsr88_randint(key: ti.types.vector(4, ti.u32), dtype=ti.u32): + key = lfsr88_next_key(key) + return key, dtype(key[0] ^ key[1] ^ key[2]) -@ti.func -def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) + @ti.func + def lfsr88_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR88 algorithm. + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -@ti.func -def lfsr88_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. - Args: - key: The state value used for random number generation. - """ - key = lfsr88_next_key(key) - return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + @ti.func + def lfsr88_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR88 algorithm. + Args: + key: The state value used for random number generation. + """ + key = lfsr88_next_key(key) + return key, (key[0] ^ key[1] ^ key[2]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -############################################## -# Random Number Generator: LFSR113 algorithm # -############################################## + ############################################## + # Random Number Generator: LFSR113 algorithm # + ############################################## -@ti.func -def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): - """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. + @ti.func + def lfsr113_key(seed: ti.u32) -> ti.types.vector(4, ti.u32): + """Initialize the random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Source: - https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c + This key is used in LFSR113 based random number generator functions, like ``lfsr113_rand()``. - /**** VERY IMPORTANT **** : - The initial seeds s1, s2, s3, s4 MUST be larger than - 1, 7, 15, and 127 respectively. - */ + Source: + https://github.com/cmcqueen/simplerandom/blob/main/c/lecuyer/lfsr113.c - Args: - seed: int. The seed value for the random number generator. + /**** VERY IMPORTANT **** : + The initial seeds s1, s2, s3, s4 MUST be larger than + 1, 7, 15, and 127 respectively. + */ - Returns: - ti.math.uvec4: The random key for the LFSR113 random number generator. - """ - return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) + Args: + seed: int. The seed value for the random number generator. + Returns: + ti.math.uvec4: The random key for the LFSR113 random number generator. + """ + return ti.math.uvec4(ti.u32(seed + 1), ti.u32(seed + 7), ti.u32(seed + 15), ti.u32(seed + 127)) -@ti.func -def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): - """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Args: - key: The state value for the random number generator. + @ti.func + def lfsr113_next_key(key: ti.types.vector(4, ti.u32)) -> ti.types.vector(4, ti.u32): + """Next random key of LFSR113 algorithm (Combined LFSR random number generator by L'Ecuyer). - Returns: - ti.math.uvec4: The next random key. - """ - z1 = key[0] - z2 = key[1] - z3 = key[2] - z4 = key[3] - b = ((z1 << 6) ^ z1) >> 13 - z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) - b = ((z2 << 2) ^ z2) >> 27 - z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) - b = ((z3 << 13) ^ z3) >> 21 - z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) - b = ((z4 << 3) ^ z4) >> 12 - z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) - return ti.math.uvec4(z1, z2, z3, z4) + Args: + key: The state value for the random number generator. + Returns: + ti.math.uvec4: The next random key. + """ + z1 = key[0] + z2 = key[1] + z3 = key[2] + z4 = key[3] + b = ((z1 << 6) ^ z1) >> 13 + z1 = ti.u32(((z1 & ti.u64(4294967294)) << 18) ^ b) + b = ((z2 << 2) ^ z2) >> 27 + z2 = ti.u32(((z2 & ti.u64(4294967288)) << 2) ^ b) + b = ((z3 << 13) ^ z3) >> 21 + z3 = ti.u32(((z3 & ti.u64(4294967280)) << 7) ^ b) + b = ((z4 << 3) ^ z4) >> 12 + z4 = ti.u32(((z4 & ti.u64(4294967168)) << 13) ^ b) + return ti.math.uvec4(z1, z2, z3, z4) -@ti.func -def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): - """ - Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - mu: The mean of the normal distribution. - sigma: The standard deviation of the normal distribution. - epsilon: The epsilon value to avoid log(0). - """ + @ti.func + def lfsr113_normal(key: ti.types.vector(4, ti.u32), mu, sigma, epsilon=1e-10): + """ + Generate a random number of the normal distribution ``N(mu, sigma)`` using the LFSR113 algorithm. - key, r = lfsr113_randn(key, epsilon) - return key, ti.cast(mu + sigma * r, defaults.ti_float) + Args: + key: The state value for the random number generator. + mu: The mean of the normal distribution. + sigma: The standard deviation of the normal distribution. + epsilon: The epsilon value to avoid log(0). + """ + key, r = lfsr113_randn(key, epsilon) + return key, ti.cast(mu + sigma * r, defaults.ti_float) -@ti.func -def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): - """ - Generate a random number with standard normal distribution using the LFSR113 algorithm. - Args: - key: The state value for the random number generator. - epsilon: The epsilon value to avoid log(0). + @ti.func + def lfsr113_randn(key: ti.types.vector(4, ti.u32), epsilon=1e-10): + """ + Generate a random number with standard normal distribution using the LFSR113 algorithm. - References: - Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform - Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method + Args: + key: The state value for the random number generator. + epsilon: The epsilon value to avoid log(0). - """ + References: + Box–Muller transform. https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform + Marsaglia polar method. https://en.wikipedia.org/wiki/Marsaglia_polar_method - key, u1 = lfsr113_rand(key) - key, u2 = lfsr113_rand(key) + """ - # Ensure state1 is not zero to avoid log(0) - u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) + key, u1 = lfsr113_rand(key) + key, u2 = lfsr113_rand(key) - # Normalize the uniform samples - mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) + # Ensure state1 is not zero to avoid log(0) + u1 = ti.cast(ti.max(u1, epsilon), defaults.ti_float) - # Box-Muller transform - # z1 = mag * ti.cos(2 * ti.math.pi * u2) - z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + # Normalize the uniform samples + mag = ti.cast(ti.sqrt(-2.0 * ti.log(u1)), defaults.ti_float) - return key, z2 + # Box-Muller transform + # z1 = mag * ti.cos(2 * ti.math.pi * u2) + z2 = ti.cast(mag * ti.sin(2 * ti.math.pi * u2), defaults.ti_float) + return key, z2 -@ti.func -def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. - Parameters: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr113_next_key(key) - return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) + @ti.func + def lfsr113_random_integers(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random integer between `low` and `high` (inclusive) using the LFSR113 algorithm. + Parameters: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr113_next_key(key) + return key, ti.cast((key[0] ^ key[1] ^ key[2] ^ key[3]) % (high + 1 - low) + low, defaults.ti_int) -@ti.func -def lfsr113_randint(key: ti.types.vector(4, ti.u32)): - key = lfsr113_next_key(key) - return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) + @ti.func + def lfsr113_randint(key: ti.types.vector(4, ti.u32)): + key = lfsr113_next_key(key) + return key, ti.cast(key[0] ^ key[1] ^ key[2] ^ key[3], defaults.ti_int) -@ti.func -def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): - """ - Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - Args: - key: The state value used for random number generation. - low: The lower bound of the range. - high: The upper bound of the range. - """ - key = lfsr88_next_key(key) - r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) - return key, ti.cast(r * (high - low) + low, defaults.ti_float) - - -@ti.func -def lfsr113_rand(key: ti.types.vector(4, ti.u32)): - """ - Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + @ti.func + def lfsr113_uniform(key: ti.types.vector(4, ti.u32), low, high): + """ + Generates a uniformly distributed random float between `low` and `high` (inclusive) using the LFSR113 algorithm. - Args: - key: The state value used for random number generation. - """ - key = lfsr113_next_key(key) - return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + Args: + key: The state value used for random number generation. + low: The lower bound of the range. + high: The upper bound of the range. + """ + key = lfsr88_next_key(key) + r = (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) + return key, ti.cast(r * (high - low) + low, defaults.ti_float) -########################### -# Reductions: warp reduce # -########################### + @ti.func + def lfsr113_rand(key: ti.types.vector(4, ti.u32)): + """ + Generates a uniformly distributed random float between 0 and 1 using the LFSR113 algorithm. + Args: + key: The state value used for random number generation. + """ + key = lfsr113_next_key(key) + return key, (key[0] ^ key[1] ^ key[2] ^ key[3]) * ti.cast(2.3283064365386963e-10, defaults.ti_float) -@ti.func -def warp_reduce_sum_all(val): - """ - Warp reduce sum. - Args: - val (float): The value to be reduced. + ########################### + # Reductions: warp reduce # + ########################### - Returns: - float: The reduced value. - """ - for i in ti.static(range(1, 32)): - val += ti.static(ti.simt.warp.shfl_xor(val, i)) - return val + @ti.func + def warp_reduce_sum_all(val): + """ + Warp reduce sum. -@ti.func -def warp_reduce_sum(val): - """ - Warp reduce sum. + Args: + val (float): The value to be reduced. - Args: - val (float): The value to be reduced. + Returns: + float: The reduced value. + """ + for i in ti.static(range(1, 32)): + val += ti.static(ti.simt.warp.shfl_xor(val, i)) + return val - Returns: - float: The reduced value. - """ - for offset in ti.static((16, 8, 4, 2, 1)): - val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) - return val + + @ti.func + def warp_reduce_sum(val): + """ + Warp reduce sum. + + Args: + val (float): The value to be reduced. + + Returns: + float: The reduced value. + """ + for offset in ti.static((16, 8, 4, 2, 1)): + val += ti.simt.warp.shfl_down_f32(ti.u32(0xFFFFFFFF), val, offset) + return val + + +else: + __all__ = [] diff --git a/brainpy/errors.py b/brainpy/errors.py index e59bb326c..37d4b9488 100644 --- a/brainpy/errors.py +++ b/brainpy/errors.py @@ -38,7 +38,16 @@ class AnalyzerError(BrainPyError): class PackageMissingError(BrainPyError): """The package missing error. """ - pass + + def __init__(self, name: str = None, purpose: str = None): + + if name is None: + super().__init__() + else: + assert purpose, '"purpose" cannot be None when "name" is provided.' + msg = (f'"{name}" must be installed when the user wants to use {purpose}. \n' + f'Please install through "pip install {name}".') + super().__init__(msg) class BackendNotInstalled(BrainPyError): @@ -236,9 +245,5 @@ def __init__(self, name): ''') - - class SharedArgError(BrainPyError): pass - - diff --git a/brainpy/math/event.py b/brainpy/math/event.py index 0a17cae7c..43d89c1b2 100644 --- a/brainpy/math/event.py +++ b/brainpy/math/event.py @@ -1,5 +1,4 @@ from brainpy._src.math.event import ( csrmv as csrmv, - info as info, ) diff --git a/brainpy/math/sparse.py b/brainpy/math/sparse.py index 1380a9e9c..fbe0acbf2 100644 --- a/brainpy/math/sparse.py +++ b/brainpy/math/sparse.py @@ -1,6 +1,5 @@ from brainpy._src.math.sparse import ( csrmv, - coomv, seg_matmul, diff --git a/requirements.txt b/requirements.txt index 02fdebe83..ab5665e73 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ numpy jax tqdm -numba -taichi==1.7.0 diff --git a/setup.py b/setup.py index d7fd45e38..766cd8c75 100644 --- a/setup.py +++ b/setup.py @@ -57,7 +57,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.8', - install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm', 'numba', 'taichi==1.7.0'], + install_requires=['numpy>=1.15', 'jax>=0.4.13', 'tqdm'], url='https://github.com/brainpy/BrainPy', project_urls={ "Bug Tracker": "https://github.com/brainpy/BrainPy/issues", @@ -68,11 +68,10 @@ 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html', ], extras_require={ - 'cpu': ['jaxlib>=0.4.13', 'brainpylib'], - 'cuda': ['jax[cuda]', 'brainpylib-cu12x'], - 'cuda11': ['jax[cuda11_local]', 'brainpylib-cu11x'], - 'cuda12': ['jax[cuda12_local]', 'brainpylib-cu12x'], - 'tpu': ['jax[tpu]'], + 'cpu': ['jaxlib>=0.4.13', 'brainpylib', 'numba', 'taichi==1.7.0'], + 'cuda11': ['jax[cuda11_pip]', 'brainpylib-cu11x', 'numba', 'taichi==1.7.0'], + 'cuda12': ['jax[cuda12_pip]', 'brainpylib-cu12x', 'numba', 'taichi==1.7.0'], + 'tpu': ['jax[tpu]', 'numba', 'taichi==1.7.0'], }, keywords=('computational neuroscience, ' 'brain-inspired computation, '