From e55a9b0ea9cd34e1577cc0d73c17a36df771478d Mon Sep 17 00:00:00 2001
From: He Sichao <1310722434@qq.com>
Date: Tue, 14 May 2024 23:06:35 +0800
Subject: [PATCH] Move the Numba MLIR CPU translation rule into
 cpu_translation.py and add tests
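
`_numba_mlir_cpu_translation_rule` moves out of `numba_approach/__init__.py`
into `cpu_translation.py` next to the other CPU lowering helpers, is exported
through `__all__`, and is bound with `functools.partial` instead of a local
closure. The generated Numba target no longer unpacks each buffer into named
`out{i}`/`in{i}` variables; it now hands the kernel the output and input
carrays as two arguments (tuples when there are several). Tests covering the
single- and multiple-result paths are added.

Under this convention a kernel receives all outputs first and all inputs
second. A minimal single-result sketch, mirroring the new test:

    import numba
    import brainpy.math as bm
    from jax.core import ShapedArray

    def eval_shape(a):
      return ShapedArray(a.shape, dtype=a.dtype)

    @numba.njit
    def kernel(outs, ins):
      outs[:] = ins + 1  # bare carrays, not tuples, for a single result

    op = bm.CustomOpByNumba(eval_shape, kernel, multiple_results=False)
    op(bm.zeros(10))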
---
 .../op_register/numba_approach/__init__.py    | 68 ++-----------------
 .../numba_approach/cpu_translation.py         | 78 ++++++++++++++++++++
 .../tests/test_numba_approach.py              | 46 +++++++++++++
 3 files changed, 130 insertions(+), 62 deletions(-)
 create mode 100644 brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py

diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
index 8d5cd3de..1ad489cf 100644
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ b/brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -6,16 +6,15 @@
 
 import jax
 from jax.interpreters import xla, batching, ad, mlir
-from jax.lib import xla_client
+
 from jax.tree_util import tree_map
-from jaxlib.hlo_helpers import custom_call
 
 from brainpy._src.dependency_check import import_numba
 from brainpy._src.math.ndarray import Array
 from brainpy._src.math.object_transform.base import BrainPyObject
-from brainpy._src.math.op_register.utils import _shape_to_layout
+
 from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
+from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule
 
 numba = import_numba(error_if_not_found=False)
 if numba is not None:
@@ -224,62 +223,6 @@ def abs_eval_rule(*input_shapes, **info):
     return prim
 
 
-def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
-  # output information
-  outs = ctx.avals_out
-  output_shapes = tuple([out.shape for out in outs])
-  output_dtypes = tuple([out.dtype for out in outs])
-  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-  result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-  # input information
-  avals_in = ctx.avals_in
-  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-  input_dtypes = tuple(inp.dtype for inp in avals_in)
-  input_shapes = tuple(inp.shape for inp in avals_in)
-
-  # compiling function
-  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
-                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
-  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-             for i in range(len(input_shapes))]
-  if len(output_shapes) > 1:
-    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                for i in range(len(output_shapes))]
-    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-  else:
-    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
-  code_string = '''
-  def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_out}
-    {args_in}
-    func_to_call({args_call})
-  '''.format(args_out="\n    ".join(args_out), args_in="\n    ".join(args_in), args_call=", ".join(args_call))
-
-  if debug:
-    print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-  new_f = code_scope['numba_cpu_custom_call_target']
-
-  # register
-  xla_c_rule = cfunc(sig)(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  # call
-  return custom_call(
-    call_target_name=target_name,
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    has_side_effect=False,
-  ).results
-
-
 def register_op_with_numba_mlir(
     op_name: str,
     cpu_func: Callable,
@@ -329,8 +272,9 @@ def abs_eval_rule(*input_shapes, **info):
   prim.def_abstract_eval(abs_eval_rule)
   prim.def_impl(partial(xla.apply_primitive, prim))
 
-  def cpu_translation_rule(ctx, *ins, **kwargs):
-    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)
+  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
+                                 cpu_func,
+                                 False)
 
   mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
 
diff --git a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
index 4b06effd..363ce6b1 100644
--- a/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
+++ b/brainpy/_src/math/op_register/numba_approach/cpu_translation.py
@@ -5,8 +5,11 @@
 from jax import dtypes, numpy as jnp
 from jax.core import ShapedArray
 from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+from jax.interpreters import mlir
 
 from brainpy._src.dependency_check import import_numba
+from brainpy._src.math.op_register.utils import _shape_to_layout
 
 numba = import_numba(error_if_not_found=False)
 ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -19,6 +22,7 @@
 __all__ = [
   '_cpu_translation',
   'compile_cpu_signature_with_numba',
+  '_numba_mlir_cpu_translation_rule',
 ]
 
 if numba is not None:
@@ -150,3 +154,77 @@ def compile_cpu_signature_with_numba(
   output_layouts = (xla_client.Shape.tuple_shape(output_layouts)
                     if multiple_results else output_layouts[0])
   return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
+
+
+def _numba_mlir_cpu_translation_rule(
+    cpu_func,
+    debug,
+    ctx,
+    *ins,
+    **kwargs
+):
+  # output information
+  outs = ctx.avals_out
+  output_shapes = tuple([out.shape for out in outs])
+  output_dtypes = tuple([out.dtype for out in outs])
+  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
+  result_types = [mlir.aval_to_ir_type(out) for out in outs]
+
+  # input information
+  avals_in = ctx.avals_in
+  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
+  input_dtypes = tuple(inp.dtype for inp in avals_in)
+  input_shapes = tuple(inp.shape for inp in avals_in)
+
+  # compiling function
+  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
+                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
+  if len(input_shapes) > 1:
+    args_in = [
+      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
+      for i in range(len(input_shapes))
+    ]
+    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
+  else:
+    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
+  if len(output_shapes) > 1:
+    args_out = [
+      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
+      for i in range(len(output_shapes))
+    ]
+    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
+    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+  else:
+    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
+    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
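+  # Build the wrapper that XLA's CPU custom call will invoke: it views the
+  # raw output/input pointers as Numba carrays, grouping them into tuples
+  # when there are several, and hands them to the user-supplied kernel.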
+  code_string = '''
+def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+  args_out = {args_out}
+  args_in = {args_in}
+  func_to_call(args_out, args_in)
+  '''.format(args_in=args_in,
+             args_out=args_out)
+
+  if debug:
+    print(code_string)
+  exec(compile(code_string.strip(), '', 'exec'), code_scope)
+  new_f = code_scope['numba_cpu_custom_call_target']
+
+  # register
+  xla_c_rule = cfunc(sig)(new_f)
+  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
+  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
+  xla_client.register_custom_call_target(target_name, capsule, "cpu")
+
+  # call
+  return custom_call(
+    call_target_name=target_name,
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    has_side_effect=False,
+  ).results
diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
new file mode 100644
index 00000000..e1bed7de
--- /dev/null
+++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
@@ -0,0 +1,46 @@
+import pytest
+from jax.core import ShapedArray
+
+import brainpy.math as bm
+from brainpy._src.dependency_check import import_numba
+
+numba = import_numba(error_if_not_found=False)
+if numba is None:
+  pytest.skip('no numba', allow_module_level=True)
+
+bm.set_platform('cpu')
+
+
+def eval_shape(a):
+  b = ShapedArray(a.shape, dtype=a.dtype)
+  return b
+
+@numba.njit(parallel=True)
+def con_compute(outs, ins):
+  b = outs
+  a = ins
+  b[:] = a + 1
+
+def test_CustomOpByNumba_single_result():
+  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
+  print(op(bm.zeros(10)))
+
+def eval_shape2(a, b):
+  c = ShapedArray(a.shape, dtype=a.dtype)
+  d = ShapedArray(b.shape, dtype=b.dtype)
+  return c, d
+
+@numba.njit(parallel=True)
+def con_compute2(outs, ins):
+  c = outs[0]  # take out all the outputs
+  d = outs[1]
+  a = ins[0]  # take out all the inputs
+  b = ins[1]
+  # c, d = outs
+  # a, b = ins
+  c[:] = a + 1
+  d[:] = b * 2
+
+def test_CustomOpByNumba_multiple_results():
+  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
+  print(op2(bm.zeros(10), bm.ones(10)))