From e3a854af9589ae761e4df357ec3fa47d8750c707 Mon Sep 17 00:00:00 2001
From: Sichao He <1310722434@qq.com>
Date: Wed, 15 May 2024 21:04:32 +0800
Subject: [PATCH] [math] Fix `CustomOpByNumba` on `multiple_results=True` (#671)

* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24

* Update dependency_check.py

* Update dependency_check.py

* Update requirements-dev.txt

* Update

* Update operator_custom_with_numba.ipynb

* Update __init__.py

* Update dependency_check.py

* Update __init__.py

* Fix

* Update docs

* Update operator_custom_with_taichi.ipynb
---
 brainpy/_src/dependency_check.py              |   6 +-
 .../op_register/numba_approach/__init__.py    |  68 +---------
 .../numba_approach/cpu_translation.py         |  76 +++++++++++
 .../tests/test_numba_approach.py              |  48 +++++++
 .../operator_custom_with_numba.ipynb          | 120 ++++++++++++++++--
 .../operator_custom_with_taichi.ipynb         |   2 +-
 6 files changed, 241 insertions(+), 79 deletions(-)
 create mode 100644 brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py

diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index 75c2051f..05a7c79c 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -25,9 +25,9 @@
 brainpylib_cpu_ops = None
 brainpylib_gpu_ops = None
 
-taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
-                       f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
-                       '> pip install taichi==1.7.0')
+taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
+                       f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n'
+                       '> pip install taichi -U')
 numba_install_info = ('We need numba. Please install numba by pip . \n'
                       '> pip install numba')
 cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
index 8d5cd3de..35c9beef 100644
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ b/brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -1,21 +1,22 @@
 # -*- coding: utf-8 -*-
 import ctypes
+import ctypes
 from functools import partial
 from typing import Callable
 from typing import Union, Sequence
 
 import jax
 from jax.interpreters import xla, batching, ad, mlir
-from jax.lib import xla_client
+
 from jax.tree_util import tree_map
 from jaxlib.hlo_helpers import custom_call
 
 from brainpy._src.dependency_check import import_numba
 from brainpy._src.math.ndarray import Array
 from brainpy._src.math.object_transform.base import BrainPyObject
-from brainpy._src.math.op_register.utils import _shape_to_layout
+
 from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
+from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule
 
 numba = import_numba(error_if_not_found=False)
 if numba is not None:
@@ -224,62 +225,6 @@ def abs_eval_rule(*input_shapes, **info):
   return prim
 
 
-def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
-  # output information
-  outs = ctx.avals_out
-  output_shapes = tuple([out.shape for out in outs])
-  output_dtypes = tuple([out.dtype for out in outs])
-  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-  result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-  # input information
-  avals_in = ctx.avals_in
-  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-  input_dtypes = tuple(inp.dtype for inp in avals_in)
-  input_shapes = tuple(inp.shape for inp in avals_in)
-
-  # compiling function
-  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
-                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
-  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-             for i in range(len(input_shapes))]
-  if len(output_shapes) > 1:
-    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                for i in range(len(output_shapes))]
-    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-  else:
-    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
-  code_string = '''
-  def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_out}
-    {args_in}
-    func_to_call({args_call})
-    '''.format(args_out="\n    ".join(args_out), args_in="\n    ".join(args_in), args_call=", ".join(args_call))
-
-  if debug:
-    print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-  new_f = code_scope['numba_cpu_custom_call_target']
-
-  # register
-  xla_c_rule = cfunc(sig)(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  # call
-  return custom_call(
-    call_target_name=target_name,
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    has_side_effect=False,
-  ).results
-
-
 def register_op_with_numba_mlir(
   op_name: str,
   cpu_func: Callable,
@@ -329,8 +274,9 @@ def abs_eval_rule(*input_shapes, **info):
   prim.def_abstract_eval(abs_eval_rule)
   prim.def_impl(partial(xla.apply_primitive, prim))
 
-  def cpu_translation_rule(ctx, *ins, **kwargs):
-    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)
+  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
+                                 cpu_func,
+                                 False)
 
   mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
 
diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
index 4b06effd..363ce6b1 100644
--- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
+++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
@@ -5,8 +5,11 @@
 from jax import dtypes, numpy as jnp
 from jax.core import ShapedArray
 from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+from jax.interpreters import mlir
 
 from brainpy._src.dependency_check import import_numba
+from brainpy._src.math.op_register.utils import _shape_to_layout
 
 numba = import_numba(error_if_not_found=False)
 ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -19,6 +22,7 @@
 __all__ = [
   '_cpu_translation',
   'compile_cpu_signature_with_numba',
+  '_numba_mlir_cpu_translation_rule',
 ]
 
 if numba is not None:
@@ -150,3 +154,75 @@ def compile_cpu_signature_with_numba(
                      if multiple_results else output_layouts[0])
   return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
 
+
+
+def _numba_mlir_cpu_translation_rule(
+    cpu_func,
+    debug,
+    ctx,
+    *ins,
+    **kwargs
+):
+  # output information
+  outs = ctx.avals_out
+  output_shapes = tuple([out.shape for out in outs])
+  output_dtypes = tuple([out.dtype for out in outs])
+  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
+  result_types = [mlir.aval_to_ir_type(out) for out in outs]
+
+  # input information
+  avals_in = ctx.avals_in
+  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
+  input_dtypes = tuple(inp.dtype for inp in avals_in)
+  input_shapes = tuple(inp.shape for inp in avals_in)
+
+  # compiling function
+  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
+                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
+  if len(input_shapes) > 1:
+    args_in = [
+      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
+      for i in range(len(input_shapes))
+    ]
+    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
+  else:
+    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
+  if len(output_shapes) > 1:
+    args_out = [
+      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
+      for i in range(len(output_shapes))
+    ]
+    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
+    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+  else:
+    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
+    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
+  # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
+  code_string = '''
+def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+  args_out = {args_out}
+  args_in = {args_in}
+  func_to_call(args_out, args_in)
+    '''.format(args_in=args_in,
+               args_out=args_out)
+
+  if debug:
+    print(code_string)
+  exec(compile(code_string.strip(), '', 'exec'), code_scope)
+  new_f = code_scope['numba_cpu_custom_call_target']
+
+  # register
+  xla_c_rule = cfunc(sig)(new_f)
+  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
+  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
+  xla_client.register_custom_call_target(target_name, capsule, "cpu")
+
+  # call
+  return custom_call(
+    call_target_name=target_name,
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    has_side_effect=False,
+  ).results
diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
new file mode 100644
index 00000000..21099cb6
--- /dev/null
+++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
@@ -0,0 +1,48 @@
+import jax.core
+import pytest
+from jax.core import ShapedArray
+
+import brainpy.math as bm
+from brainpy._src.dependency_check import import_numba
+
+numba = import_numba(error_if_not_found=False)
+if numba is None:
+  pytest.skip('no numba', allow_module_level=True)
+
+bm.set_platform('cpu')
+
+
+def eval_shape(a):
+  b = ShapedArray(a.shape, dtype=a.dtype)
+  return b
+
+@numba.njit(parallel=True)
+def con_compute(outs, ins):
+  b = outs
+  a = ins
+  b[:] = a + 1
+
+def test_CustomOpByNumba_single_result():
+  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
+  print(op(bm.zeros(10)))
+
+def eval_shape2(a, b):
+  c = ShapedArray(a.shape, dtype=a.dtype)
+  d = ShapedArray(b.shape, dtype=b.dtype)
+  return c, d
+
+def con_compute2(outs, ins):
+  c = outs[0]  # take out all the outputs
+  d = outs[1]
+  a = ins[0]  # take out all the inputs
+  b = ins[1]
+  # c, d = outs
+  # a, b = ins
+  c[:] = a + 1
+  d[:] = b * 2
+
+def test_CustomOpByNumba_multiple_results():
+  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
+  print(op2(bm.zeros(10), bm.ones(10)))
+
+test_CustomOpByNumba_multiple_results()
\ No newline at end of file
diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
index e1121f5b..7f00cd56 100644
--- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -65,8 +65,6 @@
    "source": [
     "### ``brainpy.math.CustomOpByNumba``\n",
     "\n",
-    "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n",
-    "\n",
     "BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n",
     "\n",
     "- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n",
@@ -137,7 +135,7 @@
     "collapsed": false
    },
    "source": [
-    "### Return multiple values ``multiple_returns=True``\n",
+    "#### Return multiple values ``multiple_returns=True``\n",
     "\n",
     "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n",
     "\n",
     "```python\n",
     "def eval_shape2(a, b):\n",
     "  c = ShapedArray(a.shape, dtype=a.dtype)\n",
     "  d = ShapedArray(b.shape, dtype=b.dtype)\n",
     "  return c, d\n",
     "\n",
     "def con_compute2(outs, ins):\n",
-    "  c, d = outs  # 取出所有的输出\n",
-    "  a, b = ins  # 取出所有的输入\n",
+    "  c = outs[0]  # take out all the outputs\n",
+    "  d = outs[1]\n",
+    "  a = ins[0]  # take out all the inputs\n",
+    "  b = ins[1]\n",
     "  c[:] = a + 1\n",
     "  d[:] = a * 2\n",
     "\n",
@@ -170,7 +170,7 @@
     "collapsed": false
    },
    "source": [
-    "### Non-Tracer parameters\n",
+    "#### Non-Tracer parameters\n",
     "\n",
     "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n",
     "\n",
@@ -191,7 +191,8 @@
     "\n",
     "def con_compute3(outs, ins):\n",
     "  c = outs  # Take out all the outputs\n",
-    "  a, b = ins  # Take out all inputs\n",
+    "  a = ins[0]  # Take out all inputs\n",
+    "  b = ins[1]\n",
     "  c[:] = 2.\n",
     "\n",
     "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -221,7 +222,7 @@
     "collapsed": false
    },
    "source": [
-    "### Example: A sparse operator\n",
+    "#### Example: A sparse operator\n",
     "\n",
     "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator."
    ]
   },
@@ -297,6 +298,50 @@
     "f(1.)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### brainpy.math.XLACustomOp\n",
+    "\n",
+    "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using @numba.jit or @numba.njit, and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
+    "\n",
+    "Detailed steps are as follows:\n",
+    "\n",
+    "#### Define the kernel\n",
+    "\n",
+    "```python\n",
+    "@numba.njit(fastmath=True)\n",
+    "def numba_event_csrmv(weight, indices, vector, outs):\n",
+    "    outs.fill(0)\n",
+    "    weight = weight[()]  # 0d\n",
+    "    for row_i in range(vector.shape[0]):\n",
+    "        if vector[row_i]:\n",
+    "            for j in indices[row_i]:\n",
+    "                outs[j] += weight\n",
+    "```\n",
+    "\n",
+    "In the declaration of parameters, the last few parameters need to be output parameters so that numba can compile correctly. This operator numba_event_csrmv receives four parameters: `weight`, `indices`, `vector`, and `outs`. The first three parameters are input parameters, and the last parameter is the output parameter. The output parameter is a 1D array, and the input parameters are 0D, 2D, and 1D arrays, respectively."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Registering and Using Custom Operators\n",
+    "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
+    "\n",
+    "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n",
+    "\n",
+    "```python\n",
+    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+    "indices = bm.random.randint(0, s, (s, 80))\n",
+    "vector = bm.random.rand(s) < 0.1\n",
+    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
+    "print(out)\n",
+    "```"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -423,7 +468,7 @@
     "collapsed": false
    },
    "source": [
-    "### 返回多个值 ``multiple_returns=True``\n",
+    "#### 返回多个值 ``multiple_returns=True``\n",
     "\n",
     "如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n",
     "\n",
     "```python\n",
     "def eval_shape2(a, b):\n",
     "  c = ShapedArray(a.shape, dtype=a.dtype)\n",
     "  d = ShapedArray(b.shape, dtype=b.dtype)\n",
     "  return c, d  # 返回多个抽象数组信息\n",
     "\n",
     "def con_compute2(outs, ins):\n",
-    "  c, d = outs  # 取出所有的输出\n",
-    "  a, b = ins  # 取出所有的输入\n",
+    "  c = outs[0]  # 取出所有的输出\n",
+    "  d = outs[1]\n",
+    "  a = ins[0]  # 取出所有的输入\n",
+    "  b = ins[1]\n",
     "  c[:] = a + 1\n",
     "  d[:] = a * 2\n",
     "\n",
@@ -455,7 +502,7 @@
     "collapsed": false
    },
    "source": [
-    "### 非Tracer参数\n",
+    "#### 非Tracer参数\n",
     "\n",
     "在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n",
     "\n",
@@ -476,7 +523,8 @@
     "\n",
     "def con_compute3(outs, ins):\n",
     "  c = outs  # 取出所有的输出\n",
-    "  a, b = ins  # 取出所有的输入\n",
+    "  a = ins[0]  # 取出所有的输入\n",
+    "  b = ins[1]\n",
     "  c[:] = 2.\n",
     "\n",
     "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -506,7 +554,7 @@
     "collapsed": false
    },
    "source": [
-    "### 示例:一个稀疏算子\n",
+    "#### 示例:一个稀疏算子\n",
     "\n",
     "为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。"
    ]
   },
@@ -581,6 +629,50 @@
     "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n",
     "f(1.)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### brainpy.math.XLACustomOp\n",
+    "\n",
+    "`brainpy.math.XLACustomOp` is a new method for customizing operators on the CPU device. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. If you want to use this new method with numba, you only need to define a kernel using `@numba.jit` or `@numba.njit` decorator, and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
+    "`brainpy.math.XLACustomOp`是一种自定义算子的新方法。它类似于`brainpy.math.CustomOpByNumba`,但它更灵活并支持更高级的特性。如果您想用numba使用这种新方法,只需要使用 `@numba.jit`或`@numba.njit`装饰器定义一个kernel,然后将内核传递给`brainpy.math.XLACustomOp`。\n",
+    "\n",
+    "详细步骤如下:\n",
+    "\n",
+    "#### 定义kernel\n",
+    "在参数声明中,最后几个参数需要是输出参数,这样numba才能正确编译。这个算子`numba_event_csrmv`接受四个参数:weight、indices、vector 和 outs。前三个参数是输入参数,最后一个参数是输出参数。输出参数是一个一维数组,输入参数分别是 0D、2D 和 1D 数组。\n",
+    "\n",
+    "```python\n",
+    "@numba.njit(fastmath=True)\n",
+    "def numba_event_csrmv(weight, indices, vector, outs):\n",
+    "    outs.fill(0)\n",
+    "    weight = weight[()]  # 0d\n",
+    "    for row_i in range(vector.shape[0]):\n",
+    "        if vector[row_i]:\n",
+    "            for j in indices[row_i]:\n",
+    "                outs[j] += weight\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 注册并使用自定义算子\n",
+    "在定义了自定义算子之后,可以将其注册到特定框架中,并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`,这样算子就可以在不同的设备上运行。并在调用中指定`outs`参数,用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n",
+    "\n",
+    "注意: 在算子声明的参数与调用时需要保持顺序的一致。\n",
+    "\n",
+    "```python\n",
+    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+    "indices = bm.random.randint(0, s, (s, 80))\n",
+    "vector = bm.random.rand(s) < 0.1\n",
+    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])\n",
+    "print(out)\n",
+    "```"
+   ]
+  }
 ],
 "metadata": {
diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
index 4b86a426..e927bf72 100644
--- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb
@@ -127,7 +127,7 @@
    "metadata": {},
    "source": [
     "### Registering and Using Custom Operators\n",
-    "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output.\n",
+    "After defining a custom operator, it can be registered into a specific framework and used where needed. When registering, you can specify cpu_kernel and gpu_kernel, so the operator can run on different devices. Specify the outs parameter when calling, using `jax.ShapeDtypeStruct` to define the shape and data type of the output.\n",
     "\n",
     "Note: Maintain the order of the operator's declared parameters consistent with the order when calling.\n",
     "\n",
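
For reference, the `XLACustomOp` example from the updated tutorial above, made self-contained — a minimal sketch that assumes numba is installed and picks an arbitrary size `s = 200`, since the tutorial snippet leaves `s` undefined:

```python
import jax
import numba
import brainpy.math as bm

bm.set_platform('cpu')

@numba.njit(fastmath=True)
def numba_event_csrmv(weight, indices, vector, outs):
  # outputs come last in the kernel signature, as the tutorial requires
  outs.fill(0)
  weight = weight[()]  # scalar weight arrives as a 0-d array
  for row_i in range(vector.shape[0]):
    if vector[row_i]:
      for j in indices[row_i]:
        outs[j] += weight

s = 200  # assumed size for this sketch; not specified in the tutorial
prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)
indices = bm.random.randint(0, s, (s, 80))
vector = bm.random.rand(s) < 0.1
# outs declares the output shape/dtype; argument order must match the kernel
out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([s], dtype=bm.float32)])
print(out)
```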