diff --git a/brainpy/_src/math/__init__.py b/brainpy/_src/math/__init__.py index 01159883..a28ba7d8 100644 --- a/brainpy/_src/math/__init__.py +++ b/brainpy/_src/math/__init__.py @@ -47,7 +47,6 @@ from . import random, linalg, fft # operators -from .op_register import * from .pre_syn_post import * from . import surrogate, event, sparse, jitconn diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 4d4dd25a..be4b19d1 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -4,14 +4,12 @@ import jax import numpy as np -from jax import numpy as jnp - +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found from brainpy._src.math import defaults from brainpy._src.math.interoperability import as_jax from brainpy._src.math.ndarray import Array -from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError -from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found +from jax import numpy as jnp bti = import_braintaichi(error_if_not_found=False) diff --git a/brainpy/_src/math/op_register/__init__.py b/brainpy/_src/math/op_register/__init__.py deleted file mode 100644 index 19160708..00000000 --- a/brainpy/_src/math/op_register/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from .numba_approach import (CustomOpByNumba, - register_op_with_numba_xla, - compile_cpu_signature_with_numba) -from .base import XLACustomOp -from .utils import register_general_batching -from .base import XLACustomOp -from .utils import register_general_batching diff --git a/brainpy/_src/math/op_register/ad_support.py b/brainpy/_src/math/op_register/ad_support.py deleted file mode 100644 index 54a3c9be..00000000 --- a/brainpy/_src/math/op_register/ad_support.py +++ /dev/null @@ -1,56 +0,0 @@ -import functools -from functools import partial - -from jax import tree_util -from jax.core import Primitive -from jax.interpreters import ad - -__all__ = [ - 'defjvp', -] - - -def defjvp(primitive, *jvp_rules): - """Define JVP rules for any JAX primitive. - - This function is similar to ``jax.interpreters.ad.defjvp``. - However, the JAX one only supports primitives with ``multiple_results=False``. - ``brainpy.math.defjvp`` makes it possible to define an independent JVP rule for - each input parameter, regardless of whether ``multiple_results`` is True or False. - - For examples, please see ``test_ad_support.py``. - - Args: - primitive: Primitive, XLACustomOp. - *jvp_rules: The JVP translation rule for each primal.
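# A minimal sketch of how ``defjvp`` is used with a multiple-results
# primitive (all names below are illustrative, not taken from this diff):
from jax.core import Primitive

my_prim = Primitive('my_prim')
my_prim.multiple_results = True

def jvp_of_x(x_dot, x, y):
  # Tangent contribution of the first primal to both outputs.
  return [x_dot * y, x_dot]

# One rule per primal; ``None`` marks a primal without a JVP rule.
defjvp(my_prim, jvp_of_x, None)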
- """ - assert isinstance(primitive, Primitive) - if primitive.multiple_results: - ad.primitive_jvps[primitive] = partial(_standard_jvp, jvp_rules, primitive) - else: - ad.primitive_jvps[primitive] = partial(ad.standard_jvp, jvp_rules, primitive) - - -def _standard_jvp(jvp_rules, primitive: Primitive, primals, tangents, **params): - assert primitive.multiple_results - val_out = tuple(primitive.bind(*primals, **params)) - tree = tree_util.tree_structure(val_out) - tangents_out = [] - for rule, t in zip(jvp_rules, tangents): - if rule is not None and type(t) is not ad.Zero: - r = tuple(rule(t, *primals, **params)) - tangents_out.append(r) - assert tree_util.tree_structure(r) == tree - try: - return val_out, functools.reduce(_add_tangents, - tangents_out, - tree_util.tree_map(lambda a: ad.Zero.from_primal_value(a), val_out)) - except: - return val_out, functools.reduce(_add_tangents, - tangents_out, - tree_util.tree_map(lambda a: ad.Zero.from_value(a), val_out)) - - -def _add_tangents(xs, ys): - return tree_util.tree_map(ad.add_tangents, xs, ys, is_leaf=lambda a: isinstance(a, ad.Zero)) - diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py deleted file mode 100644 index a6dd5a5b..00000000 --- a/brainpy/_src/math/op_register/base.py +++ /dev/null @@ -1,224 +0,0 @@ -from functools import partial -from typing import Callable, Sequence, Tuple, Protocol, Optional, Union - -import jax -import numpy as np -from jax.interpreters import xla, batching, ad, mlir - -from brainpy._src.dependency_check import import_numba, import_cupy_jit -from brainpy._src.math.ndarray import Array -from brainpy._src.math.object_transform.base import BrainPyObject - -is_version_right = False -if jax.__version__ >= '0.4.16': - from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule - from braintaichi._primitive._mlir_translation_rule import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, - register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) - from .cupy_based import ( - register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, - register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) - is_version_right = True - -from .utils import register_general_batching -from brainpy._src.math.op_register.ad_support import defjvp - -numba = import_numba(error_if_not_found=False) -cp_jit = import_cupy_jit(error_if_not_found=False) - -__all__ = [ - 'XLACustomOp', -] - - -class ShapeDtype(Protocol): - - @property - def shape(self) -> Tuple[int, ...]: - ... - - @property - def dtype(self) -> np.dtype: - ... - - -class XLACustomOp(BrainPyObject): - """Creating a XLA custom call operator. - - For more information, please refer to the tutorials above: - Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html - Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html - CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html - - Args: - cpu_kernel: Callable. The function defines the computation on CPU backend. - gpu_kernel: Callable. The function defines the computation on GPU backend. - batching_translation: Callable. The batching translation rule of JAX. - jvp_translation: Callable. The JVP translation rule of JAX. - transpose_translation: Callable. The transpose translation rule of JAX. 
- outs: optional. The output information. - name: str. The primitive name. - """ - - def __init__( - self, - cpu_kernel: Callable = None, - gpu_kernel: Union[Callable, str] = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - outs: Optional[Callable] = None, - name: str = None, - ): - if not is_version_right: - raise ImportError('XLA Custom Op is only supported in JAX>=0.4.16') - super().__init__(name) - - # set cpu_kernel and gpu_kernel - self.cpu_kernel = cpu_kernel - self.gpu_kernel = gpu_kernel - - # primitive - self.primitive = jax.core.Primitive(self.name) - self.primitive.multiple_results = True - - # abstract evaluation - self.outs = outs - self.primitive.def_abstract_eval(_abstract_eval) - self.primitive.def_impl(partial(xla.apply_primitive, self.primitive)) - - # cpu function - cpu_checked = False - if cpu_kernel is None: - cpu_checked = True - if numba is not None: # numba - from numba.core.dispatcher import Dispatcher - if isinstance(cpu_kernel, Dispatcher): - register_numba_cpu_translation_rule(self.primitive, cpu_kernel) - cpu_checked = True - if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi - register_taichi_cpu_translation_rule(self.primitive, cpu_kernel) - cpu_checked = True - if not cpu_checked: - raise ValueError(f'"cpu_kernel" must be a numba jitted function or a taichi kernel function. ' - f'But we got {cpu_kernel}') - - # gpu function - gpu_checked = False - if gpu_kernel is None: - gpu_checked = True - elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule - register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel) - gpu_checked = True - elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel - register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel) - gpu_checked = True - elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi - register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) - gpu_checked = True - if not gpu_checked: - raise ValueError( - f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}') - - # batching rule - if batching_translation is None: - register_general_batching(self.primitive) - else: - batching.primitive_batchers[self.primitive] = batching_translation - - # jvp rule - if jvp_translation is not None: - ad.primitive_jvps[self.primitive] = jvp_translation - - # transpose rule - if transpose_translation is not None: - ad.primitive_transposes[self.primitive] = transpose_translation - - def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None, **kwargs): - if outs is None: - if self.outs is None: - raise ValueError('The output information is not defined.') - outs = self.outs(*ins, **kwargs) - assert outs is not None - outs = tuple([_transform_to_shapedarray(o) for o in outs]) - ins = jax.tree_util.tree_map(_transform_to_array, ins, is_leaf=_is_bp_array) - return self.primitive.bind(*ins, outs=outs, **kwargs) - - def def_abstract_eval(self, fun): - """Define the abstract evaluation function. - - Args: - fun: The abstract evaluation function. - """ - self.primitive.def_abstract_eval(fun) - - def def_batching_rule(self, fun): - """Define the batching rule. - - Args: - fun: The batching rule. - """ - batching.primitive_batchers[self.primitive] = fun - - def def_jvp_rule(self, fun): - """Define the JVP rule. - - Args: - fun: The JVP rule. 
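# A minimal usage sketch, mirroring the deleted tests further down in this
# diff: a Numba-jitted CPU kernel receives the inputs first and the output
# buffers last, and the operator is called with explicit output shapes.
import jax
import numba
import brainpy.math as bm

@numba.njit
def add_one(x, out):
  out[:] = x + 1

op = bm.XLACustomOp(cpu_kernel=add_one)
y = op(bm.zeros(10), outs=[jax.ShapeDtypeStruct((10,), bm.float32)])[0]
# Because a general batching rule is registered by default (a lax.scan over
# the mapped axis), jax.vmap over ``op`` works without extra code.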
- """ - ad.primitive_jvps[self.primitive] = fun - - def defjvp(self, *jvp_rules): - """Define the JVP rule. Similar to ``jax.interpreters.ad.defjvp``, but supports the Primitive with multiple results. - - Args: - jvp_rules: The JVP rules. - """ - defjvp(self.primitive, *jvp_rules) - - def def_transpose_rule(self, fun): - """Define the transpose rule. - - Args: - fun: The transpose rule. - """ - ad.primitive_transposes[self.primitive] = fun - - def def_xla_translation(self, platform, fun): - """Define the XLA translation rule. - - Args: - platform: str. The computing platform. - fun: The XLA translation rule. - """ - xla.backend_specific_translations[platform][self.primitive] = fun - - def def_mlir_lowering(self, platform, fun): - """Define the MLIR lowering rule. - - Args: - platform: str. The computing platform. - fun: The lowering rule. - """ - mlir.register_lowering(self.primitive, fun, platform) - - -def _abstract_eval(*args, **kwargs): - return [jax.core.ShapedArray(out_shape.shape, out_shape.dtype) - for out_shape in kwargs['outs']] - - -def _is_bp_array(a): - return isinstance(a, Array) - - -def _transform_to_array(a): - if isinstance(a, Array): - return a.value - elif isinstance(a, jax.Array): - return a - else: - return jax.numpy.asarray(a) - - -def _transform_to_shapedarray(a): - return jax.core.ShapedArray(a.shape, a.dtype) diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py deleted file mode 100644 index ad6befec..00000000 --- a/brainpy/_src/math/op_register/cupy_based.py +++ /dev/null @@ -1,279 +0,0 @@ -from functools import partial, reduce -from typing import List, Tuple - -import jax -import numpy as np -from jax.interpreters import xla, mlir -from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call - -from brainpy._src.dependency_check import (import_cupy, - import_cupy_jit, - import_brainpylib_gpu_ops) -from brainpy._src.math.op_register.utils import _shape_to_layout -from brainpy.errors import PackageMissingError - -cp = import_cupy(error_if_not_found=False) -cp_jit = import_cupy_jit(error_if_not_found=False) - -# convert type to number -type_number_map = { - int: 0, - float: 1, - bool: 2, - np.dtype('int32'): 0, - np.dtype('float32'): 1, - np.dtype('bool'): 2, - np.dtype('uint8'): 3, - np.dtype('uint16'): 4, - np.dtype('uint32'): 5, - np.dtype('uint64'): 6, - np.dtype('int8'): 7, - np.dtype('int16'): 8, - np.dtype('int64'): 9, - np.dtype('float16'): 10, - np.dtype('float64'): 11, -} - - -def _preprocess_kernel_call_gpu( - grid: Tuple[int], - block: Tuple[int], - func_ptr: int, - shared_mem: int, - *ins, - outs: List[jax.ShapeDtypeStruct], -): - grid = (grid + (1, 1))[:3] - block = (block + (1, 1))[:3] - in_num = len(ins) - out_num = len(outs) - in_out_num = [in_num, out_num] - - out_type_list = [0] * out_num - out_elem_count_list = [0] * out_num - - for i, value in enumerate(outs): - out_type_list[i] = type_number_map[value.dtype] - out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape) - - grid = ",".join(str(i) for i in grid) - block = ",".join(str(i) for i in block) - in_out_num_str = ",".join(str(i) for i in in_out_num) - out_type_list_str = ",".join(str(i) for i in out_type_list) - out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list) - - opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' + - bytes(str(shared_mem), encoding='utf-8') + b';' + - bytes(in_out_num_str, encoding='utf-8') + b';' + - bytes(grid, encoding='utf-8') + b';' + - bytes(block, encoding='utf-8') + 
b';' + - bytes(out_type_list_str, encoding='utf-8') + b';' + - bytes(out_elem_count_list_str, encoding='utf-8') + b';') - return opaque - - -def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): - grid = kwargs.get('grid', None) - block = kwargs.get('block', None) - shared_mem = kwargs.get('shared_mem', 0) - if grid is None or block is None: - raise ValueError('The grid and block should be specified for the cupy kernel.') - - # preprocess - import_brainpylib_gpu_ops() - # THE KEY: - # - using the kernel pointer at "kernel.kernel.ptr" - opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) - - # create custom call - return xla_client.ops.CustomCallWithLayout( - c, - b'cupy_kernel_call_gpu', - operands=ins, - operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), - shape_with_layout=xla_client.Shape.tuple_shape( - [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) - for value in kwargs['outs']] - ), - opaque=opaque, - ) - - -def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel): - xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel) - - -def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): - grid = kwargs.get('grid', None) - block = kwargs.get('block', None) - shared_mem = kwargs.get('shared_mem', 0) - if grid is None or block is None: - raise ValueError('The grid and block should be specified for the cupy kernel.') - - # preprocess - import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) - - input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] - result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] - output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] - - return custom_call( - call_target_name='cupy_kernel_call_gpu', - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - backend_config=opaque, - has_side_effect=False, - ).results - - -def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel): - if cp is None: - raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') - - rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel) - mlir.register_lowering(primitive, rule, platform='gpu') - - -def _to_cupy_array_or_scalar(dtype, ndim): - # THE KEY - # - using the cupy jit compiler to get the type - if ndim != 0: - t = cp_jit._cuda_types.CArray(dtype=dtype, - ndim=ndim, - is_c_contiguous=True, - index_32_bits=True) - else: - t = cp_jit._cuda_types.Scalar(dtype=dtype) - return t - - -def _compile_kernel_xla(kernel, in_types): - # THE KEY - # - get the kernel function from the cache - device_id = cp.cuda.get_device_id() - kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None)) - - if kern is None: - # THE KEY: - # - compile the kernel function - result = kernel._cached_codes.get(in_types) - if result is None: - result = cp_jit._compile.transpile( - kernel._func, - ['extern "C"', '__global__'], - 'cuda', - in_types, - cp_jit._cuda_types.void, - ) - kernel._cached_codes[in_types] = result - fname = result.func_name - enable_cg = result.enable_cooperative_groups - options = result.options - backend = result.backend - if backend == 'nvcc': - options += ('-DCUPY_JIT_NVCC',) - jitify = result.jitify - module = 
cp._core.core.compile_with_cache( - source=result.code, - options=options, - backend=backend, - jitify=jitify, - ) - kern = module.get_function(fname) - kernel._cache[(in_types, device_id)] = (kern, enable_cg) - - return kern - - -def get_jit_kernel_xla(kernel, c, *ins, outs): - # get the input types - in_types = [] - for x in ins: - x = c.get_shape(x) - in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions()))) - for x in outs: - in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) - in_types = tuple(in_types) - # compile the kernel - return _compile_kernel_xla(kernel, in_types) - - -def get_jit_kernel_mlir(kernel, c): - # get the input types - in_types = [] - for x in c.avals_in: - in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) - for x in c.avals_out: - in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) - in_types = tuple(in_types) - # compile the kernel - return _compile_kernel_xla(kernel, in_types) - - -def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): - kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs']) - grid = kwargs.get('grid', None) - block = kwargs.get('block', None) - shared_mem = kwargs.get('shared_mem', 0) - if grid is None or block is None: - raise ValueError('The grid and block should be specified for the cupy kernel.') - - # preprocess - import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) - - # create custom call - return xla_client.ops.CustomCallWithLayout( - c, - b'cupy_kernel_call_gpu', - operands=ins, - operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), - shape_with_layout=xla_client.Shape.tuple_shape( - [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) - for value in kwargs['outs']] - ), - opaque=opaque, - ) - - -def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel): - xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel) - - -def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): - kernel_func = get_jit_kernel_mlir(kernel, c) - grid = kwargs.get('grid', None) - block = kwargs.get('block', None) - shared_mem = kwargs.get('shared_mem', 0) - if grid is None or block is None: - raise ValueError('The grid and block should be specified for the cupy kernel.') - - # preprocess - import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) - - input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] - result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] - output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] - - return custom_call( - call_target_name='cupy_kernel_call_gpu', - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - backend_config=opaque, - has_side_effect=False, - ).results - - -def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel): - if cp is None: - raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') - - rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel) - mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py deleted file mode 100644 index 
35c9beef..00000000 --- a/brainpy/_src/math/op_register/numba_approach/__init__.py +++ /dev/null @@ -1,295 +0,0 @@ -# -*- coding: utf-8 -*- -import ctypes -import ctypes -from functools import partial -from typing import Callable -from typing import Union, Sequence - -import jax -from jax.interpreters import xla, batching, ad, mlir - -from jax.tree_util import tree_map -from jaxlib.hlo_helpers import custom_call - -from brainpy._src.dependency_check import import_numba -from brainpy._src.math.ndarray import Array -from brainpy._src.math.object_transform.base import BrainPyObject - -from brainpy.errors import PackageMissingError -from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule - -numba = import_numba(error_if_not_found=False) -if numba is not None: - from numba import types, carray, cfunc - -__all__ = [ - 'CustomOpByNumba', - 'register_op_with_numba_xla', - 'compile_cpu_signature_with_numba', -] - - -def _transform_to_shapedarray(a): - return jax.core.ShapedArray(a.shape, a.dtype) - - -def convert_shapedarray_to_shapedtypestruct(shaped_array): - return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype) - - -class CustomOpByNumba(BrainPyObject): - """Creating a XLA custom call operator with Numba JIT on CPU backend. - - Parameters - ---------- - name: str - The name of operator. - eval_shape: callable - The function to evaluate the shape and dtype of the output according to the input. - This function should receive the abstract information of inputs, and return the - abstract information of the outputs. For example: - - >>> def eval_shape(inp1_info, inp2_info, inp3_info, ...): - >>> return out1_info, out2_info - con_compute: callable - The function to make the concrete computation. This function receives inputs, - and returns outputs. 
For example: - - >>> def con_compute(inp1, inp2, inp3, ..., out1, out2, ...): - >>> pass - """ - - def __init__( - self, - eval_shape: Callable = None, - con_compute: Callable = None, - name: str = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = True, - ): - super().__init__(name=name) - - # abstract evaluation function - if eval_shape is None: - raise ValueError('Must provide "eval_shape" for abstract evaluation.') - self.eval_shape = eval_shape - - # cpu function - cpu_func = con_compute - - # register OP - if jax.__version__ > '0.4.23': - self.op_method = 'mlir' - self.op = register_op_with_numba_mlir( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - gpu_func_translation=None, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - else: - self.op_method = 'xla' - self.op = register_op_with_numba_xla( - self.name, - cpu_func=cpu_func, - out_shapes=eval_shape, - batching_translation=batching_translation, - jvp_translation=jvp_translation, - transpose_translation=transpose_translation, - multiple_results=multiple_results, - ) - - def __call__(self, *args, **kwargs): - args = tree_map(lambda a: a.value if isinstance(a, Array) else a, - args, is_leaf=lambda a: isinstance(a, Array)) - kwargs = tree_map(lambda a: a.value if isinstance(a, Array) else a, - kwargs, is_leaf=lambda a: isinstance(a, Array)) - res = self.op.bind(*args, **kwargs) - return res - - -def register_op_with_numba_xla( - op_name: str, - cpu_func: Callable, - out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], - gpu_func_translation: Callable = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = False, -): - """ - Converting the numba-jitted function in a Jax/XLA compatible primitive. - - Parameters - ---------- - op_name: str - Name of the operators. - - cpu_func: Callable - A callable numba-jitted function or pure function (can be lambda function) running on CPU. - - out_shapes: Callable, ShapedArray, Sequence[ShapedArray], default = None - Outputs shapes of target function. `out_shapes` can be a `ShapedArray` or - a sequence of `ShapedArray`. If it is a function, it takes as input the argument - shapes and dtypes and should return correct output shapes of `ShapedArray`. - - gpu_func_translation: Callable - A callable cuda-jitted kernel running on GPU. - - batching_translation: Callable - The batching translation for the primitive. - - jvp_translation: Callable - The forward autodiff translation rule. - - transpose_translation: Callable - The backward autodiff translation rule. - - multiple_results: bool - Whether the primitive returns multiple results. Default is False. - - Returns - ------- - op: core.Primitive - A JAX Primitive object. - """ - - if numba is None: - raise PackageMissingError.by_purpose('numba', 'custom op with numba') - - if out_shapes is None: - raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' - 'a sequence of `ShapedArray`. 
If it is a function, it takes as input the argument ' - 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') - - prim = jax.core.Primitive(op_name) - prim.multiple_results = multiple_results - - # user defined function - from numba.core.dispatcher import Dispatcher - if not isinstance(cpu_func, Dispatcher): - cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) - - # output shape evaluation function - def abs_eval_rule(*input_shapes, **info): - if callable(out_shapes): - shapes = out_shapes(*input_shapes, **info) - else: - shapes = out_shapes - - if isinstance(shapes, jax.core.ShapedArray): - assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." - elif isinstance(shapes, (tuple, list)): - assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." - for elem in shapes: - if not isinstance(elem, jax.core.ShapedArray): - raise ValueError(f'Elements in "out_shapes" must be instances of ' - f'jax.abstract_arrays.ShapedArray, but we got ' - f'{type(elem)}: {elem}') - else: - raise ValueError(f'Unknown type {type(shapes)}, only ' - f'supports function, ShapedArray or ' - f'list/tuple of ShapedArray.') - return shapes - - # cpu function - prim.def_abstract_eval(abs_eval_rule) - prim.def_impl(partial(xla.apply_primitive, prim)) - xla.backend_specific_translations['cpu'][prim] = partial(_cpu_translation, - cpu_func, - abs_eval_rule, - multiple_results) - - # gpu function - if gpu_func_translation is not None: - xla.backend_specific_translations['gpu'][prim] = gpu_func_translation - - # batching - if batching_translation is not None: - batching.primitive_batchers[prim] = batching_translation - - # jvp - if jvp_translation is not None: - ad.primitive_jvps[prim] = jvp_translation - - # transpose - if transpose_translation is not None: - ad.primitive_transposes[prim] = transpose_translation - - return prim - - -def register_op_with_numba_mlir( - op_name: str, - cpu_func: Callable, - out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]], - gpu_func_translation: Callable = None, - batching_translation: Callable = None, - jvp_translation: Callable = None, - transpose_translation: Callable = None, - multiple_results: bool = False, -): - if numba is None: - raise PackageMissingError.by_purpose('numba', 'custom op with numba') - - if out_shapes is None: - raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or ' - 'a sequence of `ShapedArray`. If it is a function, it takes as input the argument ' - 'shapes and dtypes and should return correct output shapes of `ShapedArray`.') - - prim = jax.core.Primitive(op_name) - prim.multiple_results = multiple_results - - from numba.core.dispatcher import Dispatcher - if not isinstance(cpu_func, Dispatcher): - cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func) - - def abs_eval_rule(*input_shapes, **info): - if callable(out_shapes): - shapes = out_shapes(*input_shapes, **info) - else: - shapes = out_shapes - - if isinstance(shapes, jax.core.ShapedArray): - assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data." - elif isinstance(shapes, (tuple, list)): - assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data." 
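# A runnable sketch of the ``CustomOpByNumba`` workflow documented above,
# adapted from the deleted test file further down in this diff:
import numba
import brainpy.math as bm
from jax.core import ShapedArray

def eval_shape(a):
  return ShapedArray(a.shape, dtype=a.dtype)

@numba.njit
def con_compute(outs, ins):
  outs[:] = ins + 1  # fill the output buffer in place

op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
y = op(bm.zeros(10))  # expected result: an array of ones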
- for elem in shapes: - if not isinstance(elem, jax.core.ShapedArray): - raise ValueError(f'Elements in "out_shapes" must be instances of ' - f'jax.abstract_arrays.ShapedArray, but we got ' - f'{type(elem)}: {elem}') - else: - raise ValueError(f'Unknown type {type(shapes)}, only ' - f'supports function, ShapedArray or ' - f'list/tuple of ShapedArray.') - return shapes - - prim.def_abstract_eval(abs_eval_rule) - prim.def_impl(partial(xla.apply_primitive, prim)) - - cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule, - cpu_func, - False) - - mlir.register_lowering(prim, cpu_translation_rule, platform='cpu') - - if gpu_func_translation is not None: - mlir.register_lowering(prim, gpu_func_translation, platform='gpu') - - if batching_translation is not None: - jax.interpreters.batching.primitive_batchers[prim] = batching_translation - - if jvp_translation is not None: - jax.interpreters.ad.primitive_jvps[prim] = jvp_translation - - if transpose_translation is not None: - jax.interpreters.ad.primitive_transposes[prim] = transpose_translation - - return prim diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py deleted file mode 100644 index 363ce6b1..00000000 --- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py +++ /dev/null @@ -1,228 +0,0 @@ -# -*- coding: utf-8 -*- - -import ctypes - -from jax import dtypes, numpy as jnp -from jax.core import ShapedArray -from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call -from jax.interpreters import mlir - -from brainpy._src.dependency_check import import_numba -from brainpy._src.math.op_register.utils import _shape_to_layout - -numba = import_numba(error_if_not_found=False) -ctypes.pythonapi.PyCapsule_New.argtypes = [ - ctypes.c_void_p, # void* pointer - ctypes.c_char_p, # const char *name - ctypes.c_void_p, # PyCapsule_Destructor destructor -] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - -__all__ = [ - '_cpu_translation', - 'compile_cpu_signature_with_numba', - '_numba_mlir_cpu_translation_rule', -] - -if numba is not None: - from numba import types, carray, cfunc - - -def _cpu_translation(func, abs_eval_fn, multiple_results, c, *inputs, **info): - target_name, inputs, input_shapes, xla_output_shapes = \ - compile_cpu_signature_with_numba(c, func, abs_eval_fn, multiple_results, inputs, info) - return xla_client.ops.CustomCallWithLayout( - c, - target_name, - operands=inputs, - operand_shapes_with_layout=input_shapes, - shape_with_layout=xla_output_shapes, - ) - - -def _cpu_signature( - func, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - multiple_results: bool, - debug: bool = False -): - code_scope = dict( - func_to_call=func, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) - - # inputs - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - - # outputs - if multiple_results: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - - # function 
body - code_string = ''' -def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - new_f = code_scope['xla_cpu_custom_call_target'] - if multiple_results: - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), - types.CPointer(types.voidptr)))(new_f) - else: - xla_c_rule = cfunc(types.void(types.voidptr, types.CPointer(types.voidptr)))(new_f) - target_name = xla_c_rule.native_name.encode("ascii") - capsule = ctypes.pythonapi.PyCapsule_New( - xla_c_rule.address, # A CFFI pointer to a function - b"xla._CUSTOM_CALL_TARGET", # A binary string - None # PyCapsule object run at destruction - ) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - return target_name - - -def compile_cpu_signature_with_numba( - c, - func, - abs_eval_fn, - multiple_results, - inputs: tuple, - description: dict = None, -): - input_layouts = [c.get_shape(arg) for arg in inputs] - info_inputs = [] - if description is None: description = dict() - for v in description.values(): - if isinstance(v, (int, float)): - input_layouts.append(xla_client.Shape.array_shape(dtypes.canonicalize_dtype(type(v)), (), ())) - info_inputs.append(xla_client.ops.ConstantLiteral(c, v)) - elif isinstance(v, (tuple, list)): - v = jnp.asarray(v) - input_layouts.append(xla_client.Shape.array_shape(v.dtype, v.shape, tuple(range(len(v.shape) - 1, -1, -1)))) - info_inputs.append(xla_client.ops.Constant(c, v)) - else: - raise TypeError - input_layouts = tuple(input_layouts) - input_dtypes = tuple(shape.element_type() for shape in input_layouts) - input_dimensions = tuple(shape.dimensions() for shape in input_layouts) - output_abstract_arrays = abs_eval_fn(*tuple(ShapedArray(shape.dimensions(), shape.element_type()) - for shape in input_layouts[:len(inputs)]), - **description) - if isinstance(output_abstract_arrays, ShapedArray): - output_abstract_arrays = (output_abstract_arrays,) - assert not multiple_results - else: - assert multiple_results - output_shapes = tuple(array.shape for array in output_abstract_arrays) - output_dtypes = tuple(array.dtype for array in output_abstract_arrays) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - target_name = _cpu_signature(func, - input_dtypes, - input_dimensions, - output_dtypes, - output_shapes, - multiple_results, - debug=False) - output_layouts = [xla_client.Shape.array_shape(*arg) - for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_layouts = (xla_client.Shape.tuple_shape(output_layouts) - if multiple_results else - output_layouts[0]) - return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts - - -def _numba_mlir_cpu_translation_rule( - cpu_func, - debug, - ctx, - *ins, - **kwargs -): - # output information - outs = ctx.avals_out - output_shapes = tuple([out.shape for out in outs]) - output_dtypes = tuple([out.dtype for out in outs]) - output_layouts = tuple([_shape_to_layout(out.shape) for out in outs]) - result_types = [mlir.aval_to_ir_type(out) for out in outs] - - # input information - avals_in = ctx.avals_in - input_layouts = [_shape_to_layout(a.shape) for a in avals_in] - input_dtypes = tuple(inp.dtype for inp in avals_in) - input_shapes = tuple(inp.shape for inp in avals_in) - - # compiling function - code_scope = dict(func_to_call=cpu_func, 
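# For two inputs and one output, the source string assembled below expands to
# (illustrative):
#
#   def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
#     args_out = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])
#     args_in = (
#       carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0]),
#       carray(input_ptrs[1], input_shapes[1], dtype=input_dtypes[1]),
#     )
#     func_to_call(args_out, args_in)
#
# i.e. the raw XLA buffer pointers are wrapped as Numba carrays before the
# user kernel is invoked with (outputs, inputs).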
input_shapes=input_shapes, input_dtypes=input_dtypes, - output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) - if len(input_shapes) > 1: - args_in = [ - f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' - for i in range(len(input_shapes)) - ] - args_in = '(\n ' + "\n ".join(args_in) + '\n )' - else: - args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])' - if len(output_shapes) > 1: - args_out = [ - f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' - for i in range(len(output_shapes)) - ] - args_out = '(\n ' + "\n ".join(args_out) + '\n )' - sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) - else: - args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])' - sig = types.void(types.voidptr, types.CPointer(types.voidptr)) - # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))] - code_string = ''' -def numba_cpu_custom_call_target(output_ptrs, input_ptrs): - args_out = {args_out} - args_in = {args_in} - func_to_call(args_out, args_in) - '''.format(args_in=args_in, - args_out=args_out) - - if debug: - print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - new_f = code_scope['numba_cpu_custom_call_target'] - - # register - xla_c_rule = cfunc(sig)(new_f) - target_name = f'numba_custom_call_{str(xla_c_rule.address)}' - capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - - # call - return custom_call( - call_target_name=target_name, - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - has_side_effect=False, - ).results diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py deleted file mode 100644 index 21099cb6..00000000 --- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py +++ /dev/null @@ -1,48 +0,0 @@ -import jax.core -import pytest -from jax.core import ShapedArray - -import brainpy.math as bm -from brainpy._src.dependency_check import import_numba - -numba = import_numba(error_if_not_found=False) -if numba is None: - pytest.skip('no numba', allow_module_level=True) - -bm.set_platform('cpu') - - -def eval_shape(a): - b = ShapedArray(a.shape, dtype=a.dtype) - return b - -@numba.njit(parallel=True) -def con_compute(outs, ins): - b = outs - a = ins - b[:] = a + 1 - -def test_CustomOpByNumba_single_result(): - op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False) - print(op(bm.zeros(10))) - -def eval_shape2(a, b): - c = ShapedArray(a.shape, dtype=a.dtype) - d = ShapedArray(b.shape, dtype=b.dtype) - return c, d - -def con_compute2(outs, ins): - c = outs[0] # take out all the outputs - d = outs[1] - a = ins[0] # take out all the inputs - b = ins[1] - # c, d = outs - # a, b = ins - c[:] = a + 1 - d[:] = b * 2 - -def test_CustomOpByNumba_multiple_results(): - op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True) - print(op2(bm.zeros(10), bm.ones(10))) - -test_CustomOpByNumba_multiple_results() \ No newline at end of file diff --git a/brainpy/_src/math/op_register/numba_based.py b/brainpy/_src/math/op_register/numba_based.py deleted file mode 100644 index f461f427..00000000 --- a/brainpy/_src/math/op_register/numba_based.py +++ 
/dev/null @@ -1,181 +0,0 @@ -# -*- coding: utf-8 -*- - -import ctypes -from functools import partial - -from jax.interpreters import xla, mlir -from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call - -from brainpy._src.dependency_check import import_numba -from brainpy.errors import PackageMissingError -from .utils import _shape_to_layout - -numba = import_numba(error_if_not_found=False) -if numba is not None: - from numba import types, carray, cfunc - -__all__ = [ - 'register_numba_xla_cpu_translation_rule', - 'register_numba_mlir_cpu_translation_rule', -] - -# [void* pointer, -# const char *name, -# PyCapsule_Destructor destructor] -ctypes.pythonapi.PyCapsule_New.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p] -ctypes.pythonapi.PyCapsule_New.restype = ctypes.py_object - - -def _cpu_signature( - kernel, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - debug: bool = False -): - code_scope = dict( - func_to_call=kernel, - input_shapes=input_shapes, - input_dtypes=input_dtypes, - output_shapes=output_shapes, - output_dtypes=output_dtypes, - carray=carray, - ) - - # inputs, outputs, arguments - args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' - for i in range(len(input_shapes))] - args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' - for i in range(len(output_shapes))] - args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))] - - # function body - code_string = ''' - def xla_cpu_custom_call_target(output_ptrs, input_ptrs): - {args_in} - {args_out} - func_to_call({args_call}) - '''.format(args_in="\n ".join(args_in), - args_out="\n ".join(args_out), - args_call=", ".join(args_call)) - if debug: print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - - # register - new_f = code_scope['xla_cpu_custom_call_target'] - xla_c_rule = cfunc(types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)))(new_f) - target_name = f'numba_custom_call_{str(xla_c_rule.address)}' - capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - - return target_name - - -def _numba_xla_cpu_translation_rule(kernel, debug: bool, c, *ins, **kwargs): - outs = kwargs['outs'] - - # output information - output_shapes = tuple(out.shape for out in outs) - output_dtypes = tuple(out.dtype for out in outs) - output_layouts = map(lambda shape: range(len(shape) - 1, -1, -1), output_shapes) - output_infos = [xla_client.Shape.array_shape(*arg) for arg in zip(output_dtypes, output_shapes, output_layouts)] - output_infos = xla_client.Shape.tuple_shape(output_infos) - - # input information - input_layouts = tuple(c.get_shape(arg) for arg in ins) - input_dtypes = tuple(inp.element_type() for inp in input_layouts) - input_shapes = tuple(inp.dimensions() for inp in input_layouts) - - # compiling - target_name = _cpu_signature(kernel, - input_dtypes, - input_shapes, - output_dtypes, - output_shapes, - debug=debug) - - # call - return xla_client.ops.CustomCallWithLayout( - c, - target_name.encode("ascii"), - operands=tuple(ins), - operand_shapes_with_layout=input_layouts, - shape_with_layout=output_infos, - ) - - -def register_numba_xla_cpu_translation_rule(primitive, cpu_kernel, debug=False): - if numba is None: - raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') - - # do 
not support after jax >= 0.4.24 - xla.backend_specific_translations['cpu'][primitive] = partial(_numba_xla_cpu_translation_rule, - cpu_kernel, - debug) - - -def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs): - # output information - outs = ctx.avals_out - output_shapes = tuple([out.shape for out in outs]) - output_dtypes = tuple([out.dtype for out in outs]) - output_layouts = tuple([_shape_to_layout(out.shape) for out in outs]) - result_types = [mlir.aval_to_ir_type(out) for out in outs] - - # input information - avals_in = ctx.avals_in - input_layouts = [_shape_to_layout(a.shape) for a in avals_in] - input_dtypes = tuple(inp.dtype for inp in avals_in) - input_shapes = tuple(inp.shape for inp in avals_in) - - # compiling function - code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes, - output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray) - args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])' - for i in range(len(input_shapes))] - if len(output_shapes) > 1: - args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])' - for i in range(len(output_shapes))] - sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr)) - else: - args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'] - sig = types.void(types.voidptr, types.CPointer(types.voidptr)) - args_call = [f'in{i}' for i in range(len(input_shapes))] + [f'out{i}' for i in range(len(output_shapes))] - code_string = ''' -def numba_cpu_custom_call_target(output_ptrs, input_ptrs): - {args_in} - {args_out} - func_to_call({args_call}) - '''.format(args_in="\n ".join(args_in), - args_out="\n ".join(args_out), - args_call=", ".join(args_call)) - if debug: - print(code_string) - exec(compile(code_string.strip(), '', 'exec'), code_scope) - new_f = code_scope['numba_cpu_custom_call_target'] - - # register - xla_c_rule = cfunc(sig)(new_f) - target_name = f'numba_custom_call_{str(xla_c_rule.address)}' - capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None) - xla_client.register_custom_call_target(target_name, capsule, "cpu") - - # call - return custom_call( - call_target_name=target_name, - operands=ins, - operand_layouts=list(input_layouts), - result_layouts=list(output_layouts), - result_types=list(result_types), - has_side_effect=False, - ).results - - -def register_numba_mlir_cpu_translation_rule(primitive, cpu_kernel, debug=False): - if numba is None: - raise PackageMissingError.by_purpose("numba", 'register numba xla cpu translation rule') - - rule = partial(_numba_mlir_cpu_translation_rule, cpu_kernel, debug) - mlir.register_lowering(primitive, rule, platform='cpu') diff --git a/brainpy/_src/math/op_register/tests/test_ad_support.py b/brainpy/_src/math/op_register/tests/test_ad_support.py deleted file mode 100644 index 2c9f0972..00000000 --- a/brainpy/_src/math/op_register/tests/test_ad_support.py +++ /dev/null @@ -1,143 +0,0 @@ -import pytest -from typing import Tuple - -import jax -from jax import core -from jax import numpy as jnp -from jax.interpreters import ad - -import brainpy as bp -import brainpy.math as bm -from brainpy._src.dependency_check import import_numba - -numba = import_numba(error_if_not_found=False) -if numba is None: - pytest.skip('no numba', allow_module_level=True) - -bm.set_platform('cpu') - - -def csrmv(data, indices, indptr, vector, *, shape: Tuple[int, int], transpose: 
bool = False, ): - data = jnp.atleast_1d(bm.as_jax(data)) - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - vector = bm.as_jax(vector) - if vector.dtype == jnp.bool_: - vector = bm.as_jax(vector, dtype=data.dtype) - outs = [core.ShapedArray([shape[1] if transpose else shape[0]], data.dtype)] - if transpose: - return prim_trans(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose) - else: - return prim(data, indices, indptr, vector, outs=outs, shape=shape, transpose=transpose) - - -@numba.njit(fastmath=True) -def _csr_matvec_transpose_numba_imp(values, col_indices, row_ptr, vector, res_val): - res_val.fill(0) - if values.shape[0] == 1: - values = values[0] - for row_i in range(vector.shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += values * v - else: - for row_i in range(vector.shape[0]): - v = vector[row_i] - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - res_val[col_indices[j]] += v * values[j] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _csr_matvec_numba_imp(values, col_indices, row_ptr, vector, res_val): - res_val.fill(0) - # csr mat @ vec - if values.shape[0] == 1: - values = values[0] - for row_i in numba.prange(res_val.shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values * vector[col_indices[j]] - res_val[row_i] = r - else: - for row_i in numba.prange(res_val.shape[0]): - r = 0. - for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += values[j] * vector[col_indices[j]] - res_val[row_i] = r - - -def _csrmv_jvp_mat(data_dot, data, indices, indptr, v, *, shape, transpose, **kwargs): - return csrmv(data_dot, indices, indptr, v, shape=shape, transpose=transpose) - - -def _csrmv_jvp_vec(v_dot, data, indices, indptr, v, *, shape, transpose, **kwargs): - return csrmv(data, indices, indptr, v_dot, shape=shape, transpose=transpose) - - -def _csrmv_cusparse_transpose(ct, data, indices, indptr, vector, *, shape, transpose, **kwargs): - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - - ct = ct[0] - if ad.is_undefined_primal(vector): - ct_vector = csrmv(data, indices, indptr, ct, shape=shape, transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_vector) - - else: - if type(ct) is ad.Zero: - ct_data = ad.Zero(data) - else: - if data.aval.shape[0] == 1: # scalar - ct_data = csrmv(jnp.ones(1), indices, indptr, vector, shape=shape, transpose=transpose) - ct_data = jnp.inner(ct, ct_data) - else: # heterogeneous values - row, col = bm.sparse.csr_to_coo(indices, indptr) - ct_data = vector[row] * ct[col] if transpose else vector[col] * ct[row] - return ct_data, indices, indptr, vector - - -prim_trans = bm.XLACustomOp(_csr_matvec_transpose_numba_imp) -prim_trans.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec) -prim_trans.def_transpose_rule(_csrmv_cusparse_transpose) - -prim = bm.XLACustomOp(_csr_matvec_numba_imp) -prim.defjvp(_csrmv_jvp_mat, None, None, _csrmv_jvp_vec) -prim.def_transpose_rule(_csrmv_cusparse_transpose) - - -def sum_op(op): - def func(*args, **kwargs): - r = op(*args, **kwargs) - return r.sum() - - return func - - -def try_a_trial(transpose, shape): - rng = bm.random.RandomState() - conn = bp.conn.FixedProb(0.1) - indices, indptr = conn(*shape).require('pre2post') - indices = bm.as_jax(indices) - indptr = bm.as_jax(indptr) - heter_data = rng.random(indices.shape) - heter_data = 
bm.as_jax(heter_data) - vector = rng.random(shape[0] if transpose else shape[1]) - vector = bm.as_jax(vector) - - r5 = jax.grad(sum_op(lambda *args, **kwargs: bm.sparse.csrmv(*args, **kwargs)), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - r6 = jax.grad(sum_op(lambda *args, **kwargs: csrmv(*args, **kwargs)[0]), argnums=(0, 3))( - heter_data, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) - print(r5) - print(r6) - assert bm.allclose(r5[0], r6[0]) - assert bm.allclose(r5[1], r6[1][0]) - - -def test(): - transposes = [True, False] - shapes = [(100, 200), (10, 1000), (2, 2000)] - - for transpose in transposes: - for shape in shapes: - try_a_trial(transpose, shape) diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py deleted file mode 100644 index 772b6160..00000000 --- a/brainpy/_src/math/op_register/tests/test_cupy_based.py +++ /dev/null @@ -1,79 +0,0 @@ -import jax -import pytest - -import brainpy.math as bm -from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi - -cp = import_cupy(error_if_not_found=False) -cp_jit = import_cupy_jit(error_if_not_found=False) -ti = import_taichi(error_if_not_found=False) -if cp is None or ti is None: - pytest.skip('no cupy or taichi', allow_module_level=True) -bm.set_platform('cpu') - - -def test_cupy_based(): - bm.op_register.clear_taichi_aot_caches() - # Raw Module - - @ti.kernel - def simpleAdd(x1: ti.types.ndarray(ndim=2), - x2: ti.types.ndarray(ndim=2), - n: ti.types.ndarray(ndim=0), - y: ti.types.ndarray(ndim=2)): - for i, j in y: - y[i, j] = x1[i, j] + x2[i, j] - - source_code = r''' - extern "C"{ - - __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y) - { - unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < N) - { - y[tid] = x1[tid] + x2[tid]; - } - } - } - ''' - N = 10 - x1 = bm.ones((N, N)) - x2 = bm.ones((N, N)) - - mod = cp.RawModule(code=source_code) - kernel = mod.get_function('kernel') - - prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel) - - y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0] - - print(y) - assert bm.allclose(y, x1 + x2) - - # JIT Kernel - @ti.kernel - def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1), - size: ti.types.ndarray(ndim=1), - y: ti.types.ndarray(ndim=1)): - for i in y: - y[i] = x[i] - - @cp_jit.rawkernel() - def elementwise_copy(x, size, y): - tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x - ntid = cp_jit.gridDim.x * cp_jit.blockDim.x - for i in range(tid, size, ntid): - y[i] = x[i] - - size = 100 - x = bm.ones((size,)) - - prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy) - - y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0] - - print(y) - assert bm.allclose(y, x) - -# test_cupy_based() diff --git a/brainpy/_src/math/op_register/tests/test_numba_based.py b/brainpy/_src/math/op_register/tests/test_numba_based.py deleted file mode 100644 index f7adc695..00000000 --- a/brainpy/_src/math/op_register/tests/test_numba_based.py +++ /dev/null @@ -1,55 +0,0 @@ -import jax.core -import pytest -from jax.core import ShapedArray - -import brainpy.math as bm -from brainpy._src.dependency_check import import_numba - -numba = import_numba(error_if_not_found=False) -if numba is None: - pytest.skip('no numba', 
allow_module_level=True) - -bm.set_platform('cpu') - - -@numba.njit(fastmath=True) -def numba_event_csrmv(weight, indices, vector, outs): - outs.fill(0) - weight = weight[()] # 0d - for row_i in range(vector.shape[0]): - if vector[row_i]: - for j in indices[row_i]: - outs[j] += weight - - -prim = bm.XLACustomOp(numba_event_csrmv) - - -def call(s=100): - indices = bm.random.randint(0, s, (s, 80)) - vector = bm.random.rand(s) < 0.1 - out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)]) - print(out[0].shape) - - -def test_event_ELL(): - call(1000) - call(100) - bm.clear_buffer_memory() - -# CustomOpByNumba Test - -def eval_shape(a): - b = ShapedArray(a.shape, dtype=a.dtype) - return b - -@numba.njit(parallel=True) -def con_compute(outs, ins): - b = outs - a = ins - b[:] = a + 1 - -def test_CustomOpByNumba(): - op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False) - print(op(bm.zeros(10))) - assert bm.allclose(op(bm.zeros(10)), bm.ones(10)) \ No newline at end of file diff --git a/brainpy/_src/math/op_register/utils.py b/brainpy/_src/math/op_register/utils.py deleted file mode 100644 index 2a10443d..00000000 --- a/brainpy/_src/math/op_register/utils.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: utf-8 -*- - - -from functools import partial - -import jax.numpy as jnp -from jax import lax -from jax.interpreters import batching -from jax.tree_util import tree_flatten, tree_unflatten - -__all__ = [ - 'register_general_batching', -] - - -def _general_batching_rule(prim, args, axes, **kwargs): - batch_axes, batch_args, non_batch_args = [], {}, {} - for ax_i, ax in enumerate(axes): - if ax is None: - non_batch_args[f'ax{ax_i}'] = args[ax_i] - else: - batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) - batch_axes.append(ax_i) - - def f(_, x): - pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) - for i in range(len(axes))]) - return 0, prim.bind(*pars, **kwargs) - - _, outs = lax.scan(f, 0, batch_args) - out_vals, out_tree = tree_flatten(outs) - out_dim = tree_unflatten(out_tree, (0,) * len(out_vals)) - return outs, out_dim - - -def register_general_batching(prim): - batching.primitive_batchers[prim] = partial(_general_batching_rule, prim) - - -def _shape_to_layout(shape): - return tuple(range(len(shape) - 1, -1, -1)) - diff --git a/brainpy/_src/math/sparse/__init__.py b/brainpy/_src/math/sparse/__init__.py index 13c9e1e2..eec5f53c 100644 --- a/brainpy/_src/math/sparse/__init__.py +++ b/brainpy/_src/math/sparse/__init__.py @@ -1,9 +1,7 @@ # from ._coo_mv import * -# from ._bsr_mv import * from .csr_mv import * from .csr_mm import * from .utils import * -from .bsr_mm import * from .jax_prim import * diff --git a/brainpy/_src/math/sparse/bsr_mm.py b/brainpy/_src/math/sparse/bsr_mm.py deleted file mode 100644 index 19800749..00000000 --- a/brainpy/_src/math/sparse/bsr_mm.py +++ /dev/null @@ -1,462 +0,0 @@ -# -*- coding: utf-8 -*- - -from functools import partial -from typing import Tuple - -import jax.lax -import numpy as np -from jax import numpy as jnp -from jax.core import Primitive, ShapedArray -from jax.interpreters import ad, xla -from jax.lib import xla_client - -from brainpy._src.dependency_check import import_brainpylib_gpu_ops, import_numba -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) -from brainpy.errors import GPUOperatorNotFound - -numba = 
import_numba(error_if_not_found=False) - -__all__ = [ - 'bcsrmm', -] - - -def get_mask(dense_b, blockshape, blockcount): - mask = jnp.zeros(blockcount[0] * blockcount[1], dtype=jnp.bool_) - - for i in range(blockcount[1]): - for j in range(blockcount[0]): - if jnp.abs(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], - j * blockshape[0]: (j + 1) * blockshape[0]]).sum() != 0: - mask = mask.at[i * blockcount[0] + j].set(True) - mask = mask.reshape(blockcount[1], blockcount[0]) - return mask - - -def get_mask_from_ptr_indices(ptr, indices, blockcount): - mask = jnp.zeros((blockcount[1], blockcount[0]), dtype=jnp.bool_) - for idx, indice in enumerate(indices): - row_index = 0 - for ptr_ in ptr[1:]: - if idx < ptr_: - break - row_index += 1 - mask = mask.at[row_index, indice].set(True) - return mask - - -def get_data(dense_b, mask, blockshape, blockcount, n_blocks): - data = jnp.zeros( - shape=(n_blocks * blockshape[1], blockshape[0]), - dtype=jnp.float32 - ) - - assignment_count = 0 - for i in range(blockcount[1]): - for j in range(blockcount[0]): - if mask[i, j]: - data = data.at[assignment_count * blockshape[1]: (assignment_count + 1) * blockshape[1], - :].set(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], - j * blockshape[0]: (j + 1) * blockshape[0]]) - assignment_count += 1 - return data - - -def get_ptr_indices(mask, blockcount, n_blocks, block_ptr=None): - nnz = jnp.nonzero(mask) - - if block_ptr is None: - block_ptr = jnp.arange(0, len(nnz[0])) - - indices = jnp.argsort(block_ptr) - blocks = jnp.stack([nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]], axis=-1).astype( - dtype=jnp.int32 - ) - blocks = jnp.flip(blocks, axis=-1).flatten() - - X = blockcount[1] - Y = blockcount[0] - - rows = nnz[0][:] - cols = nnz[1][:] - - block_indices = jnp.zeros(X * Y, dtype=jnp.int32) - positions = rows * Y + cols - block_indices = block_indices.at[positions].set(block_ptr + 1) - block_indices = block_indices.reshape(X, Y).transpose().reshape(X * Y) - - block_ptr = block_indices[jnp.nonzero(block_indices)[0]] - 1 - - X, Y = Y, X - rows = cols - nnztt = jnp.nonzero(mask.transpose()) - cols = nnztt[:][1] - - rows = rows.astype(jnp.int32)  # the original discarded this result, leaving rows unconverted - - ptr_b = jnp.zeros((X + 1,), dtype=jnp.int32) - for row in rows: - ptr_b = ptr_b.at[row + 1].set(ptr_b[row + 1] + 1) - ptr_b = ptr_b.cumsum(0).astype(dtype=jnp.int32) - - indices_b = jnp.stack([cols, block_ptr], axis=1).astype(dtype=jnp.int32) - - return ptr_b, indices_b - - -def get_dense(ptr, indices, data, shape, blockshape): - # NOTE: compute the block count here; the original passed ``blockshape`` where - # ``get_mask_from_ptr_indices`` expects a block count. - mask = get_mask_from_ptr_indices(ptr, indices, (shape[0] // blockshape[0], shape[1] // blockshape[1])) - dense_data = jnp.zeros(shape, dtype=jnp.float32) - mask_count = 0 - for i in range(mask.shape[1]): - for j in range(mask.shape[0]): - if mask[i, j]: - dense_data = dense_data.at[ - i * blockshape[0]: (i + 1) * blockshape[0], - j * blockshape[1]: (j + 1) * blockshape[1], - ].set(data[mask_count * blockshape[0]: (mask_count + 1) * blockshape[0], :]) - mask_count += 1 - return dense_data - - -def blocksparse_matmat_multiply(dense_a, - ptr_b=None, - indices_b=None, - data_b=None, - shape_b=None, - dense_b=None, - blockshape=(32, 32), - device='cpu'): - if dense_b is not None: - # m, n, k - m = dense_a.shape[0] - k = dense_a.shape[1] - n = dense_b.shape[1] - - # blockcount - blockcount = (n // blockshape[0], k // blockshape[1]) - - # mask - mask = get_mask(dense_b, blockshape, blockcount) - - # n_blocks - n_blocks = mask.sum() - - # data_b - data_b = get_data(dense_b, mask, 
blockshape, blockcount, n_blocks) - - # ptr_b, indices_b - ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks) - else: - # m, n, k - m = dense_a.shape[0] - n = shape_b[1] - k = dense_a.shape[1] - - # blockcount - blockcount = (n // blockshape[0], k // blockshape[1]) - - mask = get_mask_from_ptr_indices(ptr_b, indices_b, blockcount) - - n_blocks = mask.sum() - - ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks) - - # NOTE: the calls below pass the keyword arguments that ``bcsrmm`` actually accepts; - # the original passed (ptr_b, indices_b, data_b) positionally in the wrong order and - # used keywords (m/n/k/block_size_*) that are not part of its signature. - if device == 'cpu': - out = bcsrmm( - dense_a, - data_b, - indices_b, - ptr_b, - shape=(k, n), - block_size=blockshape, - ) - return out - elif device == 'gpu': - out = bcsrmm( - dense_a, - data_b, - indices_b, - ptr_b, - shape=(k, n), - block_size=blockshape, - ) - return out.transpose() - else: - raise ValueError(f'Invalid device: {device}') - - -def bcsrmm( - A_data: jax.Array, - B_data: jax.Array, - B_indices: jax.Array, - B_ptr: jax.Array, - *, - shape: Tuple[int, int], - block_size: Tuple[int, int], - transpose: bool = False, - method: str = 'cutlass' -) -> jax.Array: - """Perform the matrix multiplication :math:`C = A @ B` with the BSR data structure. - - Args: - A_data: The dense matrix :math:`A`. - B_data: The data at each block of :math:`B`. - B_indices: The sparse indices of :math:`B`. - B_ptr: The connection pointer of :math:`B`. - shape: a tuple of int, indicating the array shape of :math:`B`. - block_size: a tuple of int, indicating the block size for partitioning :math:`B`. - transpose: boolean. If True, perform :math:`A @ B^T`; otherwise, perform :math:`A @ B`. - method: a string denoting the BSR sparse computing method. - - Returns: - The dense array :math:`C`. 
- """ - A_data = as_jax(A_data) - B_data = as_jax(B_data) - B_indices = as_jax(B_indices) - B_ptr = as_jax(B_ptr) - assert A_data.shape[1] == shape[0] - - if method == 'cutlass': - C = _bcsrmm_cutlass_p.bind(A_data, - B_data, - B_indices, - B_ptr, - m=A_data.shape[0], - k=shape[0], - n=shape[1], - transpose=transpose, - block_size_k=block_size[0], - block_size_n=block_size[1])[0] - return C.T - else: - raise ValueError - - -if numba is not None: - @numba.njit(fastmath=True, parallel=True, nogil=True) - def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_k, block_size_n) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val - - - @numba.njit(fastmath=True, parallel=True, nogil=True) - def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) - res_val = outs[0] - # B_data: (num_block, block_size_n, block_size_k) - A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins - block_size_k = block_size_k[()] - block_size_n = block_size_n[()] - n_block = n // block_size_n - - for ni in numba.prange(n_block): - C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) - start, end = B_inptr[ni], B_inptr[ni + 1] - ns = ni * block_size_n - ne = ns + block_size_n - for i in range(start, end): - ki = B_indices[i, 0] - ks = ki * block_size_k - ke = ki + block_size_k - bi = B_indices[i, 1] - C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) - res_val[ns: ne] = C_tmp - return res_val - - -def _bcsrmm_cutlass_abstract( - A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n -): - assert block_size_k == 32, 'cutlass based block-sparse mm only support block size (32, 32)' - assert block_size_n == 32, 'cutlass based block-sparse mm only support block size (32, 32)' - assert B_indices.shape[0] * block_size_n == B_data.shape[0] - assert block_size_k == B_data.shape[1] - assert A_data.shape[0] == m - assert A_data.shape[1] == k - assert A_data.dtype == B_data.dtype - assert n // block_size_n + 1 == B_ptr.shape[0] - return [ShapedArray(dtype=A_data.dtype, shape=(n, m))] - - -def _bcsrmm_cutlass_cpu_translation( - c, A_data, B_data, B_indices, B_ptr, *, - m, k, n, block_size_k, block_size_n -): - inputs = (A_data, B_ptr, B_indices, B_data) - description = dict(m=m, - n=n, - k=k, - block_size_k=block_size_k, - block_size_n=block_size_n) - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _bcsrmm_cutlass_imp2, - abs_eval_fn=_bcsrmm_cutlass_abstract, - multiple_results=True, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _bcsrmm_cutlass_gpu_translation(c, A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_bcsrmm_cutlass_p.name) - - 
matrix_info = c.get_shape(A_data) - dtype = matrix_info.element_type() - - opaque = gpu_ops.build_blocksparse_format_descriptor(m, - n, - k, - block_size_k, - block_size_n) - - fn = b'gpu_blocksparse_matmat' - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(A_data, B_ptr, B_indices, B_data,), - operand_shapes_with_layout=(c.get_shape(A_data), - c.get_shape(B_ptr), - c.get_shape(B_indices), - c.get_shape(B_data),), - shape_with_layout=xla_client.Shape.tuple_shape( - (xla_client.Shape.array_shape(dtype, (n, m), (1, 0)),) - ), - opaque=opaque - ) - - -# The JVP helpers receive the tangent first, then the primals in ``bind`` order; -# both now call ``bcsrmm`` with the keyword arguments it actually accepts (the -# originals used m/n/k/block_size_* keywords that are not in its signature). -def _bcsrmm_cutlass_jvp_dense_a(dense_a_dot, A_data, B_data, B_indices, B_ptr, *, m, n, k, block_size_k, - block_size_n, transpose=False): - return bcsrmm(dense_a_dot, B_data, B_indices, B_ptr, shape=(k, n), - block_size=(block_size_k, block_size_n), transpose=transpose) - - -def _bcsrmm_cutlass_jvp_data_b(data_b_dot, A_data, B_data, B_indices, B_ptr, *, m, n, k, block_size_k, - block_size_n, transpose=False): - return bcsrmm(A_data, data_b_dot, B_indices, B_ptr, shape=(k, n), - block_size=(block_size_k, block_size_n), transpose=transpose) - - -def _bcsrmm_cutlass_jvp_transpose(): - # TODO: implement - pass - - -_bcsrmm_cutlass_p = Primitive('bcsrmm_cutlass_prim') -_bcsrmm_cutlass_p.multiple_results = True -_bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_cutlass_abstract) -_bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p)) -# xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_cpu_translation -# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_cutlass_gpu_translation -# NOTE: JVP wiring remains a TODO; the original registered the empty transpose stub as the -# JVP rule, which was a bug, so that registration is not repeated here. -ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_cutlass_jvp_transpose -register_general_batching(_bcsrmm_cutlass_p)  # batching rules are keyed by the primitive, not the Python wrapper - - -def _blocksparse_matmat_back_abstract( - A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len -): - shape = (n, k) - dtype = A_data.dtype - out = ShapedArray(dtype=dtype, shape=shape) - return [out] - - -def _blocksparse_matmat_back_gpu_translation( - c, A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len -): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(_bcsrmm_cutlass_back_p.name) - matrix_info = c.get_shape(A_data) - dtype = matrix_info.element_type() - - opaque = gpu_ops.build_blocksparse_back_format_descriptor(m, - n, - k, - block_size_k, - block_size_n, - blocks_len) - - fn = b'gpu_blocksparse_matmat_back' - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(A_data, B_data, blocks,), - operand_shapes_with_layout=(c.get_shape(A_data), - c.get_shape(B_data), - c.get_shape(blocks),), - shape_with_layout=xla_client.Shape.tuple_shape( - (xla_client.Shape.array_shape(dtype, (k, n), (1, 0)),) - ), - opaque=opaque - ) - - -_bcsrmm_cutlass_back_p = Primitive('bcsrmm_cutlass_back_prim') -_bcsrmm_cutlass_back_p.multiple_results = True -_bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract) -_bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p)) -# xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation -register_general_batching(_bcsrmm_cutlass_back_p) diff --git a/brainpy/_src/math/sparse/bsr_mv.py b/brainpy/_src/math/sparse/bsr_mv.py deleted file mode 100644 index 7dc0b683..00000000 --- a/brainpy/_src/math/sparse/bsr_mv.py +++ /dev/null @@ -1,210 +0,0 @@ -from functools import partial -from typing import Union, Tuple - -import numba 
-import numpy as np -from jax import numpy as jnp -from jax.core import ShapedArray, Primitive -from jax.interpreters import ad, xla -from jax.lib import xla_client - -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import (compile_cpu_signature_with_numba, - register_general_batching) -from brainpy._src.math.sparse.utils import csr_to_coo -from brainpy._src.dependency_check import import_brainpylib_gpu_ops -from brainpy.errors import GPUOperatorNotFound - -__all__ = [ - 'cusparse_bcsr_matvec' -] - - -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins): - data, indices, indptr, vector, blocksize, shape, nnzb, transpose = ins - blocksize = blocksize[()] - outs.fill(0) - for i in range(shape[0]): - tmp = np.zeros(blocksize, dtype=data.dtype) - for j in range(indptr[i], indptr[i + 1]): - start = indices[j] * blocksize - end = start + blocksize - tmp += data[start: end] @ vector[start: end] - outs[i * blocksize: (i + 1) * blocksize] = tmp - - -# @numba.njit(fastmath=True, parallel=True, nogil=True) -# def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins): -# data, indices, indptr, vector, blocksize , shape,nnzb,transpose = ins -# blocksize = blocksize[()] -# outs.fill(0) - -# cnt=0 -# for i in range(0,shape[0]): -# outs.fill(0.0) -# tmp=[0.0]*blocksize -# for j in range(indptr[i], indptr[i + 1]): -# for p in range(0,blocksize): -# for q in range(0,blocksize): -# tmp[p] += vector[indices[j]*blocksize+q]*data[j*blocksize+p][q] -# for j in range(0,blocksize): -# outs[cnt] = tmp[j] -# cnt+=1 - - -def _cusparse_bcsr_matvec_values(values, indices, indptr, vector, *, blocksize, nnzb, shape, transpose): - # renamed from the misspelled ``_cusprase_...``; ``blocksize`` is keyword-only and must be passed as such - return cusparse_bcsr_matvec(values, indices, indptr, vector, blocksize=blocksize, nnzb=nnzb, shape=shape, transpose=transpose) - - -def cusparse_bcsr_matvec( - data: Union[float, jnp.ndarray], - indices: jnp.ndarray, - indptr: jnp.ndarray, - vector: jnp.ndarray, - *, - blocksize: int, - nnzb: int, - shape: Tuple[int, int], - method: str = 'vector', - transpose: bool = False -) -> jnp.ndarray: - data = as_jax(data) - indices = as_jax(indices) - indptr = as_jax(indptr) - vector = as_jax(vector) - if method not in ['scalar', 'vector', 'adaptive']: - raise ValueError('Only the methods scalar, vector, and adaptive are supported. ' - f'But we got {method}.') - - data = jnp.atleast_1d(data) - if not isinstance(data, jnp.ndarray): - raise TypeError(f'data must be an ndarray. But we got {type(data)}') - if data.dtype not in [jnp.float32, jnp.float64]: - raise TypeError(f'Only float32 and float64 are supported. But we got {data.dtype}.') - if data.dtype != vector.dtype: - raise TypeError('The types of data and vector should be the same. 
' - f'But we got {data.dtype} != {vector.dtype}.') - # assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - - return cusparse_bcsr_matvec_vector_p.bind(data, indices, indptr, vector, blocksize=blocksize, shape=shape, nnzb=nnzb, - transpose=transpose) - - -def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb, - transpose): - inputs = (data, indices, indptr, vector) - description = dict(blocksize=blocksize, shape=shape, nnzb=nnzb, transpose=transpose, ) - if transpose: - # the original silently fell through here (assigning an unused ``skip = 1`` and - # returning None); fail loudly instead - raise NotImplementedError('transpose=True is not implemented for the CPU translation.') - name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( - c, - _cusparse_bcsr_matvec_bsr_matvec_numba_imp, - abs_eval_fn=_cusparse_bcsr_matvec_abstract, - multiple_results=False, - inputs=inputs, - description=description - ) - return xla_client.ops.CustomCallWithLayout( - c, name, - operands=inputs, - operand_shapes_with_layout=in_layouts, - shape_with_layout=out_layouts, - ) - - -def _cusparse_bcsr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb): - gpu_ops = import_brainpylib_gpu_ops() - if gpu_ops is None: - raise GPUOperatorNotFound(cusparse_bcsr_matvec_vector_p.name) - - data_shape = c.get_shape(data) - if data_shape.element_type() == np.float32: - type_name = b'float' - elif data_shape.element_type() == np.double: - type_name = b'double' - else: - raise ValueError('Unsupported dtype: only float32 and float64 are supported.') - # NOTE: this dtype-to-kernel-name mapping may not be correct. - - opaque = gpu_ops.build_bcsrcusparsespmv_descriptor(shape[0], shape[1], blocksize, nnzb) - return xla_client.ops.CustomCallWithLayout( - c, - b'gpu_bcsr_cusparse_spmv_' + type_name, - operands=(data, indices, indptr, vector), - operand_shapes_with_layout=(c.get_shape(data), - c.get_shape(indices), - c.get_shape(indptr), - c.get_shape(vector), - ), - shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0] * blocksize,), (0,)), - opaque=opaque, - ) - - -# def _bcsr_matvec_abstract(*args, **kwargs): -# data = args[0] -# assert len(kwargs) == 1 -# shape = kwargs['shape'] -# return ShapedArray(dtype=data.dtype, shape=(shape[0],)) - -# bcsr_matvec_vector_p = register_op_with_numba( -# 'bcsr_matvec_vector', -# cpu_func=None, -# out_shapes=_bcsr_matvec_abstract, -# gpu_func_translation=_bcsr_matvec_vector_gpu_translation, -# ) - - -# def _batch_bcsr_matvec_abstract( -# values, indices, indptr, vector,block_size, *, shape, transpose=False -# ): -# return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0])) - -def _cusparse_bcsr_matvec_abstract(data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose=False): - return ShapedArray(dtype=data.dtype, shape=(shape[0] * blocksize,)) - - -def _cusparse_bcsr_matvec_jvp_values(data_dot, data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose): - return cusparse_bcsr_matvec(data_dot, indices, indptr, vector, blocksize=blocksize, nnzb=nnzb, shape=shape, - transpose=transpose) - - -def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose): - # ``nnzb`` is accepted here because ``bind`` passes it; the original signature omitted it - if ad.is_undefined_primal(indices) or ad.is_undefined_primal(indptr): - raise ValueError("Cannot transpose with respect to sparse indices.") - if ad.is_undefined_primal(vector): - ct_events = cusparse_bcsr_matvec(data, indices, indptr, ct, blocksize=blocksize, nnzb=nnzb, shape=shape, - transpose=not transpose) - return data, indices, indptr, (ad.Zero(vector) if type(ct) is ad.Zero else ct_events) - else: - if type(ct) is ad.Zero: - ct_values = ad.Zero(data) - else: - 
row, col = csr_to_coo(indices, indptr) - # Rebuilt from the original, which assigned into an empty list (an IndexError) - # and iterated the full row/col product instead of per-block (row, col) pairs. - # Assumes the (nnzb * blocksize, blocksize) data layout used by the CPU kernel. - ct_values = jnp.zeros((row.shape[0] * blocksize, blocksize), dtype=ct.dtype) - for bn in range(row.shape[0]): - i, j = row[bn], col[bn] - for p in range(blocksize): - for q in range(blocksize): - if transpose: - val = vector[i * blocksize + p] * ct[j * blocksize + q] - else: - val = vector[j * blocksize + q] * ct[i * blocksize + p] - ct_values = ct_values.at[bn * blocksize + p, q].set(val) - return ct_values, indices, indptr, vector - - -cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv') -cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract) -cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p)) -# xla.backend_specific_translations['gpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_gpu_translation -# xla.backend_specific_translations['cpu'][cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_matvec_vector_cpu_translation -ad.defjvp(cusparse_bcsr_matvec_vector_p, _cusparse_bcsr_matvec_jvp_values) -ad.primitive_transposes[cusparse_bcsr_matvec_vector_p] = _cusparse_bcsr_transpose -register_general_batching(cusparse_bcsr_matvec_vector_p) -# batching.primitive_batchers[event_csr_matvec_p] = _event_csr_matvec_batching_rule diff --git a/brainpy/_src/math/sparse/utils.py b/brainpy/_src/math/sparse/utils.py index f5b74e5e..38cfdb7b 100644 --- a/brainpy/_src/math/sparse/utils.py +++ b/brainpy/_src/math/sparse/utils.py @@ -1,22 +1,46 @@ # -*- coding: utf-8 -*- import warnings +from functools import partial from typing import Tuple import numpy as np +from brainpy._src.math.interoperability import as_jax from jax import core, numpy as jnp +from jax import lax +from jax.interpreters import batching from jax.interpreters import mlir, ad +from jax.tree_util import tree_flatten, tree_unflatten from jaxlib import gpu_sparse -from brainpy._src.math.interoperability import as_jax -from brainpy._src.math.op_register import register_general_batching - __all__ = [ 'coo_to_csr', 'csr_to_coo', 'csr_to_dense' ] +def _general_batching_rule(prim, args, axes, **kwargs): + batch_axes, batch_args, non_batch_args = [], {}, {} + for ax_i, ax in enumerate(axes): + if ax is None: + non_batch_args[f'ax{ax_i}'] = args[ax_i] + else: + batch_args[f'ax{ax_i}'] = args[ax_i] if ax == 0 else jnp.moveaxis(args[ax_i], ax, 0) + batch_axes.append(ax_i) + + def f(_, x): + pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) + for i in range(len(axes))]) + return 0, prim.bind(*pars, **kwargs) + + _, outs = lax.scan(f, 0, batch_args) + out_vals, out_tree = tree_flatten(outs) + out_dim = tree_unflatten(out_tree, (0,) * len(out_vals)) + return outs, out_dim + +def _register_general_batching(prim): + batching.primitive_batchers[prim] = partial(_general_batching_rule, prim) + def coo_to_csr( pre_ids: jnp.ndarray, @@ -153,6 +177,6 @@ def _csr_to_dense_transpose(ct, data, indices, indptr, *, shape): ad.defjvp(csr_to_dense_p, _csr_to_dense_jvp, None, None) ad.primitive_transposes[csr_to_dense_p] = _csr_to_dense_transpose mlir.register_lowering(csr_to_dense_p, _csr_to_dense_lowering) -register_general_batching(csr_to_dense_p) +_register_general_batching(csr_to_dense_p) if gpu_sparse.cuda_is_supported: mlir.register_lowering(csr_to_dense_p, _csr_to_dense_gpu_lowering, platform='cuda') diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py index 139ec08a..562c1cc1 100644 --- a/brainpy/math/__init__.py +++ b/brainpy/math/__init__.py @@ -16,7 +16,6 @@ # operators from .pre_syn_post import * -from .op_register import * from . 
import surrogate, event, sparse, jitconn # Variable and Objects for object-oriented JAX transformations diff --git a/brainpy/math/op_register.py b/brainpy/math/op_register.py deleted file mode 100644 index 8ec7f5e1..00000000 --- a/brainpy/math/op_register.py +++ /dev/null @@ -1,10 +0,0 @@ -# -*- coding: utf-8 -*- -from brainpy._src.math.op_register import ( - CustomOpByNumba, - compile_cpu_signature_with_numba, -) - -from brainpy._src.math.op_register.base import XLACustomOp -from brainpy._src.math.op_register.ad_support import defjvp - -
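Editor's note on migration: downstream code that previously batched custom primitives through the removed `brainpy.math.op_register.register_general_batching` can reproduce the behavior with the scan-based helper this diff inlines into `brainpy/_src/math/sparse/utils.py`. The sketch below is illustrative only: the primitive is hypothetical, and it assumes the private helper `_register_general_batching` remains importable from that module.

# Minimal sketch (hypothetical primitive) of attaching the general batching rule.
import jax
import jax.numpy as jnp
from jax.core import Primitive, ShapedArray
from jax.interpreters import mlir

from brainpy._src.math.sparse.utils import _register_general_batching  # private helper added above

_double_p = Primitive('illustrative_double')  # hypothetical, for illustration only
_double_p.def_abstract_eval(lambda x: ShapedArray(x.shape, x.dtype))
_double_p.def_impl(lambda x: 2 * x)
mlir.register_lowering(_double_p, mlir.lower_fun(lambda x: 2 * x, multiple_results=False))

# Key the batching rule by the primitive itself, not a Python wrapper function.
_register_general_batching(_double_p)

# vmap now lowers to a lax.scan over the leading (batched) axis.
out = jax.vmap(lambda x: _double_p.bind(x))(jnp.ones((4, 3)))
assert out.shape == (4, 3)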
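Editor's note on the removed BSR kernels: for readers auditing `bsr_mv.py` above, the CPU path computed a block-row-wise product in which both `data` and `vector` are read at offset `indices[j] * blocksize`. The NumPy reference below mirrors that indexing verbatim so the removed behavior stays documented; note that a conventional BSR layout would instead read the j-th block of `data` at `j * blocksize`, so treat this as a record of the old convention, not a recommended implementation.

import numpy as np

def bcsr_matvec_reference(data, indices, indptr, vector, blocksize, n_block_rows):
    # Mirrors the removed Numba kernel: data is stacked as (nnzb * blocksize, blocksize)
    # rows, and both data and vector are indexed by indices[j] * blocksize.
    out = np.zeros(n_block_rows * blocksize, dtype=data.dtype)
    for i in range(n_block_rows):
        tmp = np.zeros(blocksize, dtype=data.dtype)
        for j in range(indptr[i], indptr[i + 1]):
            start = indices[j] * blocksize
            tmp += data[start: start + blocksize] @ vector[start: start + blocksize]
        out[i * blocksize: (i + 1) * blocksize] = tmp
    return out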