diff --git a/brainpy/_src/math/event/_csr_matvec.py b/brainpy/_src/math/event/_csr_matvec.py
index 2e7895334..6e03be463 100644
--- a/brainpy/_src/math/event/_csr_matvec.py
+++ b/brainpy/_src/math/event/_csr_matvec.py
@@ -352,7 +352,7 @@ def _f(ct, indices, indptr, events, *, transpose):
 event_csr_matvec_batching_p = Primitive('event_csr_matvec_batching')
 event_csr_matvec_batching_p.def_abstract_eval(_batch_event_csr_matvec_abstract)
 event_csr_matvec_batching_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_batching_p))
-xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
+# xla.backend_specific_translations['cpu'][event_csr_matvec_batching_p] = _batch_event_csr_matvec_cpu_translation
 ad.defjvp(event_csr_matvec_batching_p, _batch_event_csr_matvec_jvp_values,
           None, None, _batch_event_csr_matvec_jvp_events)
 ad.primitive_transposes[event_csr_matvec_batching_p] = _batch_event_csr_matvec_transpose
@@ -597,8 +597,8 @@ def _event_csr_matvec_transpose_brainpylib(ct, values, indices, indptr, events,
 event_csr_matvec_p = Primitive('event_csr_matvec')
 event_csr_matvec_p.def_abstract_eval(_event_csr_matvec_abstract)
 event_csr_matvec_p.def_impl(partial(xla.apply_primitive, event_csr_matvec_p))
-xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
-xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
+# xla.backend_specific_translations['cpu'][event_csr_matvec_p] = _event_csr_matvec_cpu_translation
+# xla.backend_specific_translations['gpu'][event_csr_matvec_p] = _event_csr_matvec_gpu_translation
 ad.defjvp(event_csr_matvec_p, _event_csr_matvec_jvp_values_brainpylib, None, None,
           _event_csr_matvec_jvp_events_brainpylib)
 ad.primitive_transposes[event_csr_matvec_p] = _event_csr_matvec_transpose_brainpylib
diff --git a/brainpy/_src/math/event/_info_collection.py b/brainpy/_src/math/event/_info_collection.py
index 9f8a5f31a..5f6acbb09 100644
--- a/brainpy/_src/math/event/_info_collection.py
+++ b/brainpy/_src/math/event/_info_collection.py
@@ -6,15 +6,16 @@
 import numba
 from jax import dtypes, numpy as jnp
 from jax.core import ShapedArray
-from jax.interpreters import batching
 from jax.lib import xla_client
 
+from brainpy._src.dependency_check import import_brainpylib_gpu_ops
+from brainpy._src.dependency_check import import_taichi
 from brainpy._src.math.interoperability import as_jax
-from brainpy._src.math.op_register import register_op_with_numba
 from brainpy._src.math.ndarray import Array
-from brainpy._src.dependency_check import import_brainpylib_gpu_ops
+from brainpy._src.math.op_register.base import XLACustomOp
 from brainpy.errors import GPUOperatorNotFound
 
+ti = import_taichi()
 
 __all__ = [
   'info'
@@ -40,7 +41,7 @@ def info(events: Union[Array, jax.Array]) -> Tuple[jax.Array, jax.Array]:
   events = as_jax(events)
   if events.ndim != 1:
     raise TypeError('Only support 1D boolean vector.')
-  return event_info_p.bind(events)
+  return event_info_p(events)
 
 
 def _batch_event_info_abstract(events):
@@ -66,11 +67,26 @@ def _batch_event_info(outs, ins):
       event_num[batch_idx] = num
 
 
+@ti.kernel
+def _batch_event_info_taichi(events: ti.types.ndarray(ndim=2),
+                             event_ids: ti.types.ndarray(ndim=2),
+                             event_num: ti.types.ndarray(ndim=1)):
+  for i, j in ti.grouped(ti.ndrange(event_ids.shape[0], event_ids.shape[1])):
+    event_ids[i, j] = -1
+  for batch_idx in range(event_ids.shape[0]):
+    num = 0
+    for i in range(event_ids.shape[1]):
+      if events[batch_idx, i]:
+        event_ids[batch_idx, num] = i
+        num += 1
+    event_num[batch_idx] = num
+
+
 def _batch_event_info_batching_rule(args, axes):
   arg = jnp.moveaxis(args[0], axes[0], 0)
   shape = arg.shape
   arg = jnp.reshape(arg, (shape[0] * shape[1], shape[2]))
-  event_ids, event_num = batch_event_info_p.bind(arg)
+  event_ids, event_num = batch_event_info_p(arg)
   return ((jnp.reshape(event_ids, shape), jnp.reshape(event_num, shape[:2])),
           (0, 0))
 
@@ -121,17 +137,15 @@ def _event_info_gpu_translation(c, events):
   )
 
 
-batch_event_info_p = register_op_with_numba(
-  op_name='event_info',
-  cpu_func=_batch_event_info,
-  out_shapes=_batch_event_info_abstract,
-  gpu_func_translation=_event_info_gpu_translation,
-  multiple_results=True
+batch_event_info_p = XLACustomOp(
+  name='batched_event_info',
+  cpu_kernel=_batch_event_info_taichi,
+  outs=_batch_event_info_abstract,
 )
-batching.primitive_batchers[batch_event_info_p] = _batch_event_info_batching_rule
+batch_event_info_p.def_batching_rule(_batch_event_info_batching_rule)
 
 
-def _event_info_abstract(events):
+def _event_info_abstract(events, **kwargs):
   assert events.ndim == 1
   # assert events.dtype == jnp.bool_
   event_ids = ShapedArray(dtype=dtypes.canonicalize_dtype(int), shape=events.shape)
@@ -140,7 +154,7 @@
 
 
 # TODO: first parallel evaluate the sub-sections, then serially event the sub-results.
-@numba.njit(fastmath=True)
+@numba.jit(fastmath=True)
 def _event_info(outs, ins):
   event_ids, event_num = outs
   event_num.fill(0)
@@ -154,16 +168,30 @@ def _event_info(outs, ins):
   event_num[0] = num
 
 
+@ti.kernel
+def _event_info_taichi(events: ti.types.ndarray(ndim=1),
+                       event_ids: ti.types.ndarray(ndim=1),
+                       event_num: ti.types.ndarray(ndim=1)):
+  for i in range(event_ids.shape[0]):
+    event_ids[i] = -1
+  num = 0
+  ti.loop_config(serialize=True)  # the scan below must run in order, not in parallel
+  for i in range(event_ids.shape[0]):
+    if events[i]:
+      event_ids[num] = i
+      num += 1
+  event_num[0] = num
+
+
 def _event_info_batching_rule(args, axes):
   arg = jnp.moveaxis(args[0], axes[0], 0)
-  return (batch_event_info_p.bind(arg), (0, 0))
+  return (batch_event_info_p(arg), (0, 0))
 
 
-event_info_p = register_op_with_numba(
-  op_name='event_info',
-  cpu_func=_event_info,
-  out_shapes=_event_info_abstract,
-  gpu_func_translation=_event_info_gpu_translation,
-  multiple_results=True
+event_info_p = XLACustomOp(
+  name='event_info',
+  cpu_kernel=_event_info_taichi,
+  outs=_event_info_abstract,
+  # gpu_func_translation=_event_info_gpu_translation,
 )
-batching.primitive_batchers[event_info_p] = _event_info_batching_rule
+event_info_p.def_batching_rule(_event_info_batching_rule)
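With `register_op_with_numba` gone, the event-info primitives above are wired through the Taichi-backed `XLACustomOp`, and the op object is now called directly instead of going through `primitive.bind`. A minimal sketch of that calling convention, assuming a made-up `double_it` kernel (not part of this diff):

```python
import jax
import jax.numpy as jnp
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

ti = import_taichi()


@ti.kernel
def double_it(x: ti.types.ndarray(ndim=1),
              out: ti.types.ndarray(ndim=1)):
  # hypothetical kernel: outputs are the trailing parameters
  for i in range(out.shape[0]):
    out[i] = x[i] * 2.0


prim = bm.XLACustomOp(
  cpu_kernel=double_it,
  # `outs` is a callable computing output metadata from the inputs,
  # mirroring how `_batch_event_info_abstract` is passed above
  outs=lambda x, **kwargs: [jax.ShapeDtypeStruct(x.shape, x.dtype)],
)

y, = prim(jnp.ones(8))  # the op object is called directly; no `.bind`
```

Passing a callable for `outs` lets the output metadata track the input shapes, which is also why `_event_info_abstract` gains a `**kwargs` parameter above.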
diff --git a/brainpy/_src/math/jitconn/_event_matvec.py b/brainpy/_src/math/jitconn/_event_matvec.py
index 7971b4a92..3671755a9 100644
--- a/brainpy/_src/math/jitconn/_event_matvec.py
+++ b/brainpy/_src/math/jitconn/_event_matvec.py
@@ -358,8 +358,8 @@ def _event_matvec_prob_homo_transpose(
 event_mv_prob_homo_p.multiple_results = True
 event_mv_prob_homo_p.def_abstract_eval(_event_matvec_prob_homo_abstract)
 event_mv_prob_homo_p.def_impl(partial(xla.apply_primitive, event_mv_prob_homo_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_homo_p] = _event_matvec_prob_homo_gpu_translation
 ad.primitive_jvps[event_mv_prob_homo_p] = _event_matvec_prob_homo_jvp
 ad.primitive_transposes[event_mv_prob_homo_p] = _event_matvec_prob_homo_transpose
 register_general_batching(event_mv_prob_homo_p)
@@ -529,8 +529,8 @@ def _event_matvec_prob_uniform_transpose(
 event_mv_prob_uniform_p.multiple_results = True
 event_mv_prob_uniform_p.def_abstract_eval(_event_matvec_prob_uniform_abstract)
 event_mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, event_mv_prob_uniform_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_uniform_p] = _event_matvec_prob_uniform_gpu_translation
 register_general_batching(event_mv_prob_uniform_p)
 ad.primitive_jvps[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_jvp
 ad.primitive_transposes[event_mv_prob_uniform_p] = _event_matvec_prob_uniform_transpose
@@ -723,8 +723,8 @@ def _event_matvec_prob_normal_transpose(
 event_mv_prob_normal_p.multiple_results = True
 event_mv_prob_normal_p.def_abstract_eval(_event_matvec_prob_normal_abstract)
 event_mv_prob_normal_p.def_impl(partial(xla.apply_primitive, event_mv_prob_normal_p))
-xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
-xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
+# xla.backend_specific_translations['cpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_cpu_translation
+# xla.backend_specific_translations['gpu'][event_mv_prob_normal_p] = _event_matvec_prob_normal_gpu_translation
 register_general_batching(event_mv_prob_normal_p)
 ad.primitive_jvps[event_mv_prob_normal_p] = _event_matvec_prob_normal_jvp
 ad.primitive_transposes[event_mv_prob_normal_p] = _event_matvec_prob_normal_transpose
diff --git a/brainpy/_src/math/jitconn/_matvec.py b/brainpy/_src/math/jitconn/_matvec.py
index e33a0ab1e..0caa9c996 100644
--- a/brainpy/_src/math/jitconn/_matvec.py
+++ b/brainpy/_src/math/jitconn/_matvec.py
@@ -640,8 +640,8 @@ def _matvec_prob_homo_transpose(
 mv_prob_homo_p.multiple_results = True
 mv_prob_homo_p.def_abstract_eval(_matvec_prob_homo_abstract)
 mv_prob_homo_p.def_impl(partial(xla.apply_primitive, mv_prob_homo_p))
-xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_homo_p] = _matvec_prob_homo_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_homo_p] = _matvec_prob_homo_gpu_translation
 register_general_batching(mv_prob_homo_p)
 ad.primitive_jvps[mv_prob_homo_p] = _matvec_prob_homo_jvp
 ad.primitive_transposes[mv_prob_homo_p] = _matvec_prob_homo_transpose
@@ -823,8 +823,8 @@ def _matvec_prob_uniform_transpose(
 mv_prob_uniform_p.multiple_results = True
 mv_prob_uniform_p.def_abstract_eval(_matvec_prob_uniform_abstract)
 mv_prob_uniform_p.def_impl(partial(xla.apply_primitive, mv_prob_uniform_p))
-xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_uniform_p] = _matvec_prob_uniform_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_uniform_p] = _matvec_prob_uniform_gpu_translation
 register_general_batching(mv_prob_uniform_p)
 ad.primitive_jvps[mv_prob_uniform_p] = _matvec_prob_uniform_jvp
 ad.primitive_transposes[mv_prob_uniform_p] = _matvec_prob_uniform_transpose
@@ -1009,8 +1009,8 @@ def _matvec_prob_normal_transpose(
 mv_prob_normal_p.multiple_results = True
 mv_prob_normal_p.def_abstract_eval(_matvec_prob_normal_abstract)
 mv_prob_normal_p.def_impl(partial(xla.apply_primitive, mv_prob_normal_p))
-xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
-xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
+# xla.backend_specific_translations['cpu'][mv_prob_normal_p] = _matvec_prob_normal_cpu_translation
+# xla.backend_specific_translations['gpu'][mv_prob_normal_p] = _matvec_prob_normal_gpu_translation
 register_general_batching(mv_prob_normal_p)
 ad.primitive_jvps[mv_prob_normal_p] = _matvec_prob_normal_jvp
 ad.primitive_transposes[mv_prob_normal_p] = _matvec_prob_normal_transpose
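All of the `xla.backend_specific_translations` registrations in these jitconn files (and in the sparse modules later in this diff) are commented out because that registry no longer exists in recent JAX releases. Where a lowering is still needed, this diff re-registers it through the MLIR interface; a sketch of the version-gated pattern, using a hypothetical `register_cpu_lowering` helper (the string comparison mirrors the check used in `op_register/base.py` below):

```python
import jax
from jax.interpreters import mlir, xla


def register_cpu_lowering(primitive, xla_rule, mlir_rule):
  # illustrative helper, not part of this PR
  if jax.__version__ >= '0.4.16':
    mlir.register_lowering(primitive, mlir_rule, platform='cpu')
  else:
    xla.backend_specific_translations['cpu'][primitive] = xla_rule
```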
""" - self.value = brainpy.math.random.randn(*self.shape) + global bm + if bm is None: from brainpy import math as bm + self.value = bm.random.randn(*self.shape) return self def cuda(self): diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 353892178..032a0fab6 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -732,6 +732,7 @@ def _get_for_loop_transform( unroll: int, unroll_kwargs: tools.DotDict ): + @functools.wraps(body_fun) def fun2scan(carry, x): for k in dyn_vars.keys(): dyn_vars[k]._value = carry[k] @@ -912,6 +913,7 @@ def for_loop( dyn_vars[key]._value = dyn_vals[key] if progress_bar: bar.close() + del dyn_vals, dyn_vars return out_vals diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index bc5f4c15a..1824ac911 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -8,18 +8,18 @@ from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject -# if jax.__version__ >= '0.4.16': -# from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule -# else: -# from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule -from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule -from .taichi_aot_based import (register_taichi_cpu_translation_rule, - register_taichi_gpu_translation_rule, - clean_caches) + +if jax.__version__ >= '0.4.16': + from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule + from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, + register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) +else: + from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule + from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, + register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp - __all__ = [ 'XLACustomOp', ] @@ -64,8 +64,8 @@ class XLACustomOp(BrainPyObject): >>> >>> # option 2 >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun, - >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), - >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) + >>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype), + >>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)]) >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000)) Args: @@ -74,7 +74,7 @@ class XLACustomOp(BrainPyObject): batching_translation: Callable. The batching translation rule of JAX. jvp_translation: Callable. The JVP translation rule of JAX. transpose_translation: Callable. The transpose translation rule of JAX. - outs: optional, sequence of `ShapeDtype`. The output information. + outs: optional. The output information. name: str. The primitive name. 
""" @@ -85,7 +85,7 @@ def __init__( batching_translation: Callable = None, jvp_translation: Callable = None, transpose_translation: Callable = None, - outs: Optional[Sequence[ShapeDtype]] = None, + outs: Optional[Callable] = None, name: str = None, ): super().__init__(name) @@ -99,8 +99,6 @@ def __init__( self.primitive.multiple_results = True # abstract evaluation - if outs is not None: - outs = tuple([_transform_to_shapedarray(o) for o in outs]) self.outs = outs self.primitive.def_abstract_eval(_abstract_eval) self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) @@ -139,10 +137,11 @@ def __init__( if transpose_translation is not None: ad.primitive_transposes[self.primitive] = transpose_translation - def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs): if outs is None: - outs = self.outs + if self.outs is None: + raise ValueError('The output information is not defined.') + outs = self.outs(*ins, **kwargs) assert outs is not None outs = tuple([_transform_to_shapedarray(o) for o in outs]) ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array) @@ -227,5 +226,3 @@ def _transform_to_array(a): def _transform_to_shapedarray(a): return jax.core.ShapedArray(a.shape, a.dtype) - - diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py index 76362215e..cc2ce5b4c 100644 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ b/brainpy/_src/math/op_register/numba_approach/__init__.py @@ -6,7 +6,7 @@ from typing import Union, Sequence import numba -from jax import core +import jax from jax.interpreters import xla, batching, ad from jax.tree_util import tree_map from numba.core.dispatcher import Dispatcher @@ -40,8 +40,8 @@ class CustomOpByNumba(BrainPyObject): The function to make the concrete computation. This function receives inputs, and returns outputs. For example: - >>> def con_compute(inp1, inp2, inp3, ...): - >>> return out1, out2 + >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): + >>> pass """ def __init__( @@ -86,7 +86,7 @@ def __call__(self, *args, **kwargs): def register_op_with_numba( op_name: str, cpu_func: Callable, - out_shapes: Union[Callable, core.ShapedArray, Sequence[core.ShapedArray]], + out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], gpu_func_translation: Callable = None, batching_translation: Callable = None, jvp_translation: Callable = None, @@ -130,12 +130,19 @@ def register_op_with_numba( A JAX Primitive object. """ + if jax.__version__ > '0.4.23': + raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are ' + f'only supported in JAX version <= 0.4.23. \n' + f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. ' + f'For more information, please refer to the documentation: ' + f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.') + if out_shapes is None: raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' 'a sequence of `ShapedArray`. 
If it is a function, it takes as input the argument ' 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') - prim = core.Primitive(op_name) + prim = jax.core.Primitive(op_name) prim.multiple_results = multiple_results # user defined function @@ -149,12 +156,12 @@ def abs_eval_rule(*input_shapes, **info): else: shapes = out_shapes - if isinstance(shapes, core.ShapedArray): + if isinstance(shapes, jax.core.ShapedArray): assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." elif isinstance(shapes, (tuple, list)): assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." for elem in shapes: - if not isinstance(elem, core.ShapedArray): + if not isinstance(elem, jax.core.ShapedArray): raise ValueError(f'Elements in "out_shapes" must be instances of ' f'jax.abstract_arrays.ShapedArray, but we got ' f'{type(elem)}: {elem}') diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py index fb51b5dbf..fb76aed24 100644 --- a/brainpy/_src/math/op_register/numba_based.py +++ b/brainpy/_src/math/op_register/numba_based.py @@ -16,6 +16,10 @@ 'register_numba_mlir_cpu_translation_rule', ] + +# [void* pointer, +# const char *name, +# PyCapsule_Destructor destructor] ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object @@ -100,6 +104,7 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): + # do not support after jax >= 0.4.24 xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, debug) @@ -124,38 +129,44 @@ def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs): output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' for i in range(len(input_shapes))] - args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' - for i in range(len(output_shapes))] + if len(output_shapes) > 1: + args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' + for i in range(len(output_shapes))] + sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) + else: + args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'] + sig = types.void(types.voidptr, types.CPointer(types.voidptr)) args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))] code_string = ''' - def numba_cpu_custom_call_target(output_ptrs, input_ptrs): +def numba_cpu_custom_call_target(output_ptrs, input_ptrs): {args_in} {args_out} func_to_call({args_call}) '''.format(args_in="\n ".join(args_in), args_out="\n ".join(args_out), args_call=", ".join(args_call)) - if debug: print(code_string) + if debug: + print(code_string) exec(compile(code_string.strip(), '', 'exec'), code_scope) new_f = code_scope['numba_cpu_custom_call_target'] # register - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f) + xla_c_rule = cfunc(sig)(new_f) target_name = f'numba_custom_call_{str(xla_c_rule.address)}' capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) 
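For users hitting the new `RuntimeError`, the migration path it names looks roughly like this on JAX > 0.4.23: the same in-place Numba kernel, registered through `XLACustomOp`. The `add_one` kernel is a made-up example, not part of this PR:

```python
import numba
import jax.numpy as jnp
from jax.core import ShapedArray
import brainpy.math as bm


@numba.njit
def add_one(x, out):
  # XLACustomOp's Numba convention: inputs first, outputs last, written in place
  out[:] = x + 1.


prim = bm.XLACustomOp(cpu_kernel=add_one)
# when `outs` was not given at construction, it is supplied per call
y, = prim(jnp.zeros(8), outs=[ShapedArray((8,), jnp.float32)])
```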
diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py
index fb51b5dbf..fb76aed24 100644
--- a/brainpy/_src/math/op_register/numba_based.py
+++ b/brainpy/_src/math/op_register/numba_based.py
@@ -16,6 +16,10 @@
   'register_numba_mlir_cpu_translation_rule',
 ]
 
+
+# [void* pointer,
+#  const char *name,
+#  PyCapsule_Destructor destructor]
 ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p]
 ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object
 
@@ -100,6 +104,7 @@ def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs):
 
 
 def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False):
+  # not supported on JAX >= 0.4.24
   xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, cpu_kernel, debug)
 
 
@@ -124,38 +129,44 @@ def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
                      output_shapes=output_shapes,
                      output_dtypes=output_dtypes,
                      carray=carray)
   args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
              for i in range(len(input_shapes))]
-  args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-              for i in range(len(output_shapes))]
+  if len(output_shapes) > 1:
+    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
+                for i in range(len(output_shapes))]
+    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+  else:
+    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
+    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
   args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))]
   code_string = '''
-  def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
     {args_in}
    {args_out}
     func_to_call({args_call})
 '''.format(args_in="\n    ".join(args_in),
            args_out="\n    ".join(args_out),
            args_call=", ".join(args_call))
 
-  if debug: print(code_string)
+  if debug:
+    print(code_string)
   exec(compile(code_string.strip(), '', 'exec'), code_scope)
   new_f = code_scope['numba_cpu_custom_call_target']
 
   # register
-  xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f)
+  xla_c_rule = cfunc(sig)(new_f)
   target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
   capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
   xla_client.register_custom_call_target(target_name, capsule, "cpu")
 
   # call
-  call = custom_call(call_target_name=target_name,
-                     operands=list(ins),
-                     operand_layouts=list(input_layouts),
-                     result_layouts=list(output_layouts),
-                     result_types=list(result_types)).results
-  return call
+  return custom_call(
+    call_target_name=target_name,
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    has_side_effect=False,
+  ).results
 
 
 def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False):
   rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug)
   mlir.register_lowering(primitive, rule, platform='cpu')
-
-
diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py
index dda5d5799..7fac4452d 100644
--- a/brainpy/_src/math/op_register/taichi_aot_based.py
+++ b/brainpy/_src/math/op_register/taichi_aot_based.py
@@ -12,10 +12,13 @@
 import jax.core
 import numpy as np
-from jax.interpreters import xla
+from jax.interpreters import xla, mlir
 from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
 
-from brainpy._src.dependency_check import import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops
+from brainpy._src.dependency_check import (import_taichi,
+                                           import_brainpylib_cpu_ops,
+                                           import_brainpylib_gpu_ops)
 from .utils import _shape_to_layout
 
@@ -330,11 +333,18 @@ def _preprocess_kernel_call_gpu(
   return opaque
 
 
+
+
 def _XlaOp_to_ShapedArray(c, xla_op):
   xla_op = c.get_shape(xla_op)
   return jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type())
 
 
+def _mlir_to_ShapedArray(c, op):
+  return op
+
+
 def _kernel_to_code(kernel, abs_ins, abs_outs, platform):
   codes = f'[taichi {platform} kernel]\n' + get_source_with_dependencies(kernel)
   codes += '\n[ins]: {}'.format("-".join([f'{v.dtype}[{v.shape}]' for v in abs_ins]))
@@ -342,17 +352,16 @@
   return codes
 
 
-def _compile_kernel(kernel, c, platform, *ins, **kwargs):
+def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
   # input and output abstract information
   abs_outs = kwargs['outs']
-  abs_ins = [_XlaOp_to_ShapedArray(c, v) for v in ins]
 
   # kernel to code
   codes = _kernel_to_code(kernel, abs_ins, abs_outs, platform)
   source_md5_encode = os.path.join(kernel.__name__, encode_md5(codes))
 
   # create ins, outs dict from kernel's args
-  in_num = len(ins)
+  in_num = len(abs_ins)
   names = tuple(inspect.signature(kernel).parameters.keys())
   in_names, out_names = names[:in_num], names[in_num:]
   ins_dict = {key: (abs_ins[i].dtype, abs_ins[i].shape) for i, key in enumerate(in_names)}
@@ -382,8 +391,16 @@
     raise ValueError(f'Unknown platform: {platform}')
 
 
+def _get_abs_ins(c, ins):
+  abs_ins = []
+  for v in ins:
+    xla_op = c.get_shape(v)
+    abs_ins.append(jax.core.ShapedArray(xla_op.dimensions(), xla_op.element_type()))
+  return abs_ins
+
+
-def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs):
-  in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs)
+def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs):
+  in_out_info = _compile_kernel(_get_abs_ins(c, ins), kernel, 'cpu', **kwargs)
   ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins)
   if is_metal_device:
     fn = b'taichi_kernel_aot_call_cpu_arm64'
@@ -402,8 +419,8 @@ def _taichi_xla_cpu_translation_rule(kernel, c, *ins, **kwargs):
   )
 
 
-def _taichi_gpu_translation_rule(kernel, c, *ins, **kwargs):
-  opaque = _compile_kernel(kernel, c, 'gpu', *ins, **kwargs)
+def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
+  opaque = _compile_kernel(_get_abs_ins(c, ins), kernel, 'gpu', **kwargs)
   return xla_client.ops.CustomCallWithLayout(
     c,
     b'taichi_kernel_aot_call_gpu',
@@ -417,9 +434,61 @@ def _taichi_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
   )
 
 
-def register_taichi_cpu_translation_rule(primitive, cpu_kernel):
-  xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_cpu_translation_rule, cpu_kernel)
+def register_taichi_aot_xla_cpu_translation_rule(primitive, cpu_kernel):
+  xla.backend_specific_translations['cpu'][primitive] = partial(_taichi_xla_cpu_translation_rule, cpu_kernel)
+
+
+def register_taichi_aot_xla_gpu_translation_rule(primitive, gpu_kernel):
+  xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_xla_gpu_translation_rule, gpu_kernel)
+
+
+def _taichi_mlir_cpu_translation_rule(kernel, c, *ins, **kwargs):
+  in_out_info = _compile_kernel(c.avals_in, kernel, 'cpu', **kwargs)
+  ins = [mlir.ir_constant(v) for v in in_out_info] + list(ins)
+  input_layouts = [_shape_to_layout(arr.shape) for arr in in_out_info] + [_shape_to_layout(a.shape) for a in c.avals_in]
+  output_layouts = tuple([_shape_to_layout(out.shape) for out in c.avals_out])
+  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
+  if is_metal_device:
+    if len(output_layouts) == 1:
+      fn = 'taichi_kernel_aot_call_cpu_arm64_single_result'
+    else:
+      fn = 'taichi_kernel_aot_call_cpu_arm64'
+  else:
+    if len(output_layouts) == 1:
+      fn = 'taichi_kernel_aot_call_cpu_single_result'
+    else:
+      fn = 'taichi_kernel_aot_call_cpu'
+  return custom_call(
+    call_target_name=fn,
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    has_side_effect=False,
+  ).results
+
+
+def _taichi_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
+  opaque = _compile_kernel(c.avals_in, kernel, 'gpu', **kwargs)
+  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
+  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
+  output_layouts = [_shape_to_layout(out.shape) for out in c.avals_out]
+  return custom_call(
+    call_target_name='taichi_kernel_aot_call_gpu',
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    backend_config=opaque,
+    has_side_effect=False,
+  ).results
+
+
+def register_taichi_aot_mlir_cpu_translation_rule(primitive, cpu_kernel):
+  rule = partial(_taichi_mlir_cpu_translation_rule, cpu_kernel)
+  mlir.register_lowering(primitive, rule, platform='cpu')
 
 
-def register_taichi_gpu_translation_rule(primitive, gpu_kernel):
-  xla.backend_specific_translations['gpu'][primitive] = partial(_taichi_gpu_translation_rule, gpu_kernel)
+def register_taichi_aot_mlir_gpu_translation_rule(primitive, gpu_kernel):
+  rule = partial(_taichi_mlir_gpu_translation_rule, gpu_kernel)
+  mlir.register_lowering(primitive, rule, platform='gpu')
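Note that `_compile_kernel` now receives abstract inputs directly and still infers the input/output split positionally: the first `len(abs_ins)` parameters of the Taichi kernel become `ins_dict`, the rest `outs_dict`. A kernel intended for `XLACustomOp` therefore declares outputs last; a sketch with an illustrative `scale` kernel (not part of this PR):

```python
from brainpy._src.dependency_check import import_taichi

ti = import_taichi()


@ti.kernel
def scale(x: ti.types.ndarray(ndim=1),      # input  -> ins_dict
          alpha: ti.types.ndarray(ndim=1),  # input  -> ins_dict
          out: ti.types.ndarray(ndim=1)):   # output -> outs_dict
  for i in range(out.shape[0]):
    out[i] = x[i] * alpha[0]
```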
diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py
index 547bbdc7c..5a9343642 100644
--- a/brainpy/_src/math/op_register/tests/test_ad_support.py
+++ b/brainpy/_src/math/op_register/tests/test_ad_support.py
@@ -1,136 +1,136 @@
-from typing import Tuple
-
-import jax
-import numba
-from jax import core
-from jax import numpy as jnp
-from jax.interpreters import ad
-
-import brainpy as bp
-import brainpy.math as bm
-
-
-def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
-  data = jnp.atleast_1d(bm.as_jax(data))
-  indices = bm.as_jax(indices)
-  indptr = bm.as_jax(indptr)
-  vector = bm.as_jax(vector)
-  if vector.dtype == jnp.bool_:
-    vector = bm.as_jax(vector, dtype=data.dtype)
-  outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
-  if transpose:
-    return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
-  else:
-    return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
-
-
-@numba.njit(fastmath=True)
-def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
-  res_val.fill(0)
-  if values.shape[0] == 1:
-    values = values[0]
-    for row_i in range(vector.shape[0]):
-      v = vector[row_i]
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        res_val[col_indices[j]] += values * v
-  else:
-    for row_i in range(vector.shape[0]):
-      v = vector[row_i]
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        res_val[col_indices[j]] += v * values[j]
-
-
-@numba.njit(fastmath=True, parallel=True, nogil=True)
-def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
-  res_val.fill(0)
-  # csr mat @ vec
-  if values.shape[0] == 1:
-    values = values[0]
-    for row_i in numba.prange(res_val.shape[0]):
-      r = 0.
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        r += values * vector[col_indices[j]]
-      res_val[row_i] = r
-  else:
-    for row_i in numba.prange(res_val.shape[0]):
-      r = 0.
-      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
-        r += values[j] * vector[col_indices[j]]
-      res_val[row_i] = r
-
-
-def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
-  return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
-
-
-def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
-  return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
-
-
-def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
-  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
-    raise ValueError("Cannot transpose with respect to sparse indices.")
-
-  ct = ct[0]
-  if ad.is_undefined_primal(vector):
-    ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
-    return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
-
-  else:
-    if type(ct) is ad.Zero:
-      ct_data = ad.Zero(data)
-    else:
-      if data.aval.shape[0] == 1:  # scalar
-        ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
-        ct_data = jnp.inner(ct, ct_data)
-      else:  # heterogeneous values
-        row, col = bm.sparse.csr_to_coo(indices, indptr)
-        ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
-    return ct_data, indices, indptr, vector
-
-
-prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
-prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
-
-prim = bm.XLACustomOp(_csr_matvec_numba_imp)
-prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
-prim.def_transpose_rule(_csrmv_cusparse_transpose)
-
-
-def sum_op(op):
-  def func(*args, **kwargs):
-    r = op(*args, **kwargs)
-    return r.sum()
-
-  return func
-
-
-def try_a_trial(transpose, shape):
-  rng = bm.random.RandomState()
-  conn = bp.conn.FixedProb(0.1)
-  indices, indptr = conn(*shape).require('pre2post')
-  indices = bm.as_jax(indices)
-  indptr = bm.as_jax(indptr)
-  heter_data = rng.random(indices.shape)
-  heter_data = bm.as_jax(heter_data)
-  vector = rng.random(shape[0] if transpose else shape[1])
-  vector = bm.as_jax(vector)
-
-  r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs, method='vector')), argnums=(0, 3))(
-    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-  r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
-    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
-  print(r5)
-  print(r6)
-  assert bm.allclose(r5[0], r6[0])
-  assert bm.allclose(r5[1], r6[1][0])
-
-
-def test():
-  transposes = [True, False]
-  shapes = [(100, 200), (10, 1000), (2, 2000)]
-
-  for transpose in transposes:
-    for shape in shapes:
-      try_a_trial(transpose, shape)
+from typing import Tuple
+
+import jax
+import numba
+from jax import core
+from jax import numpy as jnp
+from jax.interpreters import ad
+
+import brainpy as bp
+import brainpy.math as bm
+
+
+def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: bool = False, ):
+  data = jnp.atleast_1d(bm.as_jax(data))
+  indices = bm.as_jax(indices)
+  indptr = bm.as_jax(indptr)
+  vector = bm.as_jax(vector)
+  if vector.dtype == jnp.bool_:
+    vector = bm.as_jax(vector, dtype=data.dtype)
+  outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)]
+  if transpose:
+    return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+  else:
+    return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose)
+
+
+@numba.njit(fastmath=True)
+def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val):
+  res_val.fill(0)
+  if values.shape[0] == 1:
+    values = values[0]
+    for row_i in range(vector.shape[0]):
+      v = vector[row_i]
+      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+        res_val[col_indices[j]] += values * v
+  else:
+    for row_i in range(vector.shape[0]):
+      v = vector[row_i]
+      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+        res_val[col_indices[j]] += v * values[j]
+
+
+@numba.njit(fastmath=True, parallel=True, nogil=True)
+def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val):
+  res_val.fill(0)
+  # csr mat @ vec
+  if values.shape[0] == 1:
+    values = values[0]
+    for row_i in numba.prange(res_val.shape[0]):
+      r = 0.
+      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+        r += values * vector[col_indices[j]]
+      res_val[row_i] = r
+  else:
+    for row_i in numba.prange(res_val.shape[0]):
+      r = 0.
+      for j in range(row_ptr[row_i], row_ptr[row_i + 1]):
+        r += values[j] * vector[col_indices[j]]
+      res_val[row_i] = r
+
+
+def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+  return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose)
+
+
+def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs):
+  return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose)
+
+
+def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs):
+  if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr):
+    raise ValueError("Cannot transpose with respect to sparse indices.")
+
+  ct = ct[0]
+  if ad.is_undefined_primal(vector):
+    ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose)
+    return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector)
+
+  else:
+    if type(ct) is ad.Zero:
+      ct_data = ad.Zero(data)
+    else:
+      if data.aval.shape[0] == 1:  # scalar
+        ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose)
+        ct_data = jnp.inner(ct, ct_data)
+      else:  # heterogeneous values
+        row, col = bm.sparse.csr_to_coo(indices, indptr)
+        ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row]
+    return ct_data, indices, indptr, vector
+
+
+prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp)
+prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim_trans.def_transpose_rule(_csrmv_cusparse_transpose)
+
+prim = bm.XLACustomOp(_csr_matvec_numba_imp)
+prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec)
+prim.def_transpose_rule(_csrmv_cusparse_transpose)
+
+
+def sum_op(op):
+  def func(*args, **kwargs):
+    r = op(*args, **kwargs)
+    return r.sum()
+
+  return func
+
+
+def try_a_trial(transpose, shape):
+  rng = bm.random.RandomState()
+  conn = bp.conn.FixedProb(0.1)
+  indices, indptr = conn(*shape).require('pre2post')
+  indices = bm.as_jax(indices)
+  indptr = bm.as_jax(indptr)
+  heter_data = rng.random(indices.shape)
+  heter_data = bm.as_jax(heter_data)
+  vector = rng.random(shape[0] if transpose else shape[1])
+  vector = bm.as_jax(vector)
+
+  r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs)), argnums=(0, 3))(
+    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+  r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))(
+    heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose)
+  print(r5)
+  print(r6)
+  assert bm.allclose(r5[0], r6[0])
+  assert bm.allclose(r5[1], r6[1][0])
+
+
+def test():
+  transposes = [True, False]
+  shapes = [(100, 200), (10, 1000), (2, 2000)]
+
+  for transpose in transposes:
+    for shape in shapes:
+      try_a_trial(transpose, shape)
diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py
index 7f405ec12..03023754c 100644
--- a/brainpy/_src/math/op_register/tests/test_taichi_based.py
+++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py
@@ -30,8 +30,19 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
       for j in range(num_cols):
         update_output(out, indices[i, j], weight_val)
 
+@ti.kernel
+def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
+                  vector: ti.types.ndarray(ndim=1),
+                  weight: ti.types.ndarray(ndim=1),
+                  out: ti.types.ndarray(ndim=1)):
+  weight_val = get_weight(weight)
+  num_rows, num_cols = indices.shape
+  for i in range(num_rows):
+    if vector[i]:
+      for j in range(num_cols):
+        update_output(out, indices[i, j], weight_val)
 
-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)
 
 
 def test_taichi_op_register():
diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py
index 0acd2010b..453ab387d 100644
--- a/brainpy/_src/math/sparse/_bsr_mm.py
+++ b/brainpy/_src/math/sparse/_bsr_mm.py
@@ -404,8 +404,8 @@ def _bcsrmm_cutlass_jvp_transpose():
 _bcsrmm_cutlass_p.multiple_results = True
 _bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_cutlass_abstract)
 _bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p))
-xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
-xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
+# xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation
+# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation
 ad.primitive_jvps[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
 ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose
 register_general_batching(bcsrmm)
@@ -456,5 +456,5 @@ def _blocksparse_matmat_back_gpu_translation(
 _bcsrmm_cutlass_back_p.multiple_results = True
 _bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
 _bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p))
-xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
+# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation
 register_general_batching(_bcsrmm_cutlass_back_p)
diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/_bsr_mv.py
index 76d1715e0..a35895bc1 100644
--- a/brainpy/_src/math/sparse/_bsr_mv.py
+++ b/brainpy/_src/math/sparse/_bsr_mv.py
@@ -202,8 +202,8 @@ def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, sh
 cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv')
 cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract)
 cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p))
-xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
-xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
+# xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation
+# xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation
 ad.defjvp(cusparse_bcsr_matvec_vector_p, _cusparse_bcsr_matvec_jvp_values)
 ad.primitive_transposes[cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_transpose
 register_general_batching(cusparse_bcsr_matvec_vector_p)
diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py
index 47704af04..377597579 100644
--- a/brainpy/_src/math/sparse/_csr_mv.py
+++ b/brainpy/_src/math/sparse/_csr_mv.py
@@ -301,7 +301,7 @@ def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, trans
 _csrmv_cusparse_p = core.Primitive('cusparse_csr_matvec')
 _csrmv_cusparse_p.def_abstract_eval(_csrmv_abstract)
 _csrmv_cusparse_p.def_impl(partial(xla.apply_primitive, _csrmv_cusparse_p))
-xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_cusparse_p] = _csrmv_cpu_translation
 ad.defjvp(_csrmv_cusparse_p,
           partial(_csrmv_jvp_mat, _csrmv_cusparse_p),
           None,
@@ -372,8 +372,8 @@ def _csrmv_scalar_transpose(ct, data, indices, indptr, vector, *, shape, transpo
 _csrmv_scalar_p = core.Primitive('csr_matvec_scalar')
 _csrmv_scalar_p.def_abstract_eval(_csrmv_abstract)
 _csrmv_scalar_p.def_impl(partial(xla.apply_primitive, _csrmv_scalar_p))
-xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_scalar_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_scalar_p] = _csr_matvec_scalar_gpu_translation
 ad.defjvp(_csrmv_scalar_p,
           partial(_csrmv_jvp_mat, _csrmv_scalar_p),
           None,
@@ -443,8 +443,8 @@ def _csrmv_vector_transpose(ct, data, indices, indptr, vector, *, shape, transpo
 _csrmv_vector_p = core.Primitive('csr_matvec_vector')
 _csrmv_vector_p.def_abstract_eval(_csrmv_abstract)
 _csrmv_vector_p.def_impl(partial(xla.apply_primitive, _csrmv_vector_p))
-xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_vector_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_vector_p] = _csr_matvec_vector_gpu_translation
 ad.defjvp(_csrmv_vector_p,
           partial(_csrmv_jvp_mat, _csrmv_vector_p),
           None,
@@ -515,8 +515,8 @@ def _csrmv_adaptive_transpose(ct, data, indices, indptr, vector, *, shape, trans
 _csrmv_adaptive_p = core.Primitive('csr_matvec_adaptive')
 _csrmv_adaptive_p.def_abstract_eval(_csrmv_abstract)
 _csrmv_adaptive_p.def_impl(partial(xla.apply_primitive, _csrmv_adaptive_p))
-xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation
-xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation
+# xla.backend_specific_translations['cpu'][_csrmv_adaptive_p] = _csrmv_cpu_translation
+# xla.backend_specific_translations['gpu'][_csrmv_adaptive_p] = _csr_matvec_adaptive_gpu_translation
 ad.defjvp(_csrmv_adaptive_p,
           partial(_csrmv_jvp_mat, _csrmv_adaptive_p),
          None,