[math] Fix CustomOpByNumba on multiple_results=True #671

Merged: 13 commits, May 15, 2024
6 changes: 3 additions & 3 deletions brainpy/_src/dependency_check.py
@@ -25,9 +25,9 @@
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
'> pip install taichi==1.7.0')
taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n'
'> pip install taichi -U')
numba_install_info = ('We need numba. Please install numba by pip . \n'
'> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
68 changes: 7 additions & 61 deletions brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -1,21 +1,22 @@
# -*- coding: utf-8 -*-
import ctypes
from functools import partial
from typing import Callable
from typing import Union, Sequence

import jax
from jax.interpreters import xla, batching, ad, mlir
from jax.lib import xla_client

from jax.tree_util import tree_map
from jaxlib.hlo_helpers import custom_call

from brainpy._src.dependency_check import import_numba
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math.op_register.utils import _shape_to_layout

from brainpy.errors import PackageMissingError
from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule

numba = import_numba(error_if_not_found=False)
if numba is not None:
@@ -224,62 +225,6 @@ def abs_eval_rule(*input_shapes, **info):
return prim


def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
  # output information
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function
  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
             for i in range(len(input_shapes))]
  if len(output_shapes) > 1:
    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
                for i in range(len(output_shapes))]
    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
  else:
    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  {args_out}
  {args_in}
  func_to_call({args_call})
  '''.format(args_out="\n  ".join(args_out), args_in="\n  ".join(args_in), args_call=", ".join(args_call))

  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register
  xla_c_rule = cfunc(sig)(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call
  return custom_call(
    call_target_name=target_name,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results
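For contrast with the replacement below, here is a self-contained sketch of the code string the removed rule builds for a hypothetical op with two inputs and two outputs. The names carray, input_shapes, input_dtypes, and func_to_call are only resolved inside code_scope when the real rule runs; this snippet just reproduces the text that gets exec'd.

# Illustration only: rebuild the generated source essentially as the removed rule would
# for a hypothetical 2-input / 2-output op.
args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
           for i in range(2)]
args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
            for i in range(2)]
args_call = ['out0', 'out1', 'in0', 'in1']
print('''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  {args_out}
  {args_in}
  func_to_call({args_call})
'''.format(args_out="\n  ".join(args_out), args_in="\n  ".join(args_in), args_call=", ".join(args_call)))

The user kernel ends up being called as func_to_call(out0, out1, in0, in1), one positional argument per buffer, whereas a CustomOpByNumba kernel is written against the (outs, ins) pair shown in the tests further down; that mismatch appears to be what broke multiple_results=True.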


def register_op_with_numba_mlir(
    op_name: str,
    cpu_func: Callable,
@@ -329,8 +274,9 @@ def abs_eval_rule(*input_shapes, **info):
  prim.def_abstract_eval(abs_eval_rule)
  prim.def_impl(partial(xla.apply_primitive, prim))

  def cpu_translation_rule(ctx, *ins, **kwargs):
    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)
  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
                                 cpu_func,
                                 False)

  mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
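In the hunk above, the local cpu_translation_rule closure is the removed spelling and the functools.partial assignment is its replacement. Both expose the (ctx, *ins, **kwargs) signature that mlir.register_lowering expects, as this minimal, self-contained sketch shows; _rule and my_kernel are hypothetical stand-ins for _numba_mlir_cpu_translation_rule and the user's cpu_func.

from functools import partial

def _rule(cpu_func, debug, ctx, *ins, **kwargs):
  # stand-in for _numba_mlir_cpu_translation_rule: just echo what it was given
  return (cpu_func, debug, ctx, ins)

def my_kernel(outs, ins):
  pass

# removed spelling: a local closure
def rule_old(ctx, *ins, **kwargs):
  return _rule(my_kernel, False, ctx, *ins, **kwargs)

# new spelling: partial pre-binds the kernel and the debug flag
rule_new = partial(_rule, my_kernel, False)

assert rule_old('ctx', 1, 2) == rule_new('ctx', 1, 2)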

76 changes: 76 additions & 0 deletions brainpy/_src/math/op_register/numba_approach/cpu_translation.py
@@ -5,8 +5,11 @@
from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from jax.interpreters import mlir

from brainpy._src.dependency_check import import_numba
from brainpy._src.math.op_register.utils import _shape_to_layout

numba = import_numba(error_if_not_found=False)
ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -19,6 +22,7 @@
__all__ = [
  '_cpu_translation',
  'compile_cpu_signature_with_numba',
  '_numba_mlir_cpu_translation_rule',
]

if numba is not None:
@@ -150,3 +154,75 @@ def compile_cpu_signature_with_numba(
if multiple_results else
output_layouts[0])
return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts


def _numba_mlir_cpu_translation_rule(
    cpu_func,
    debug,
    ctx,
    *ins,
    **kwargs
):
  # output information
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function
  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  if len(input_shapes) > 1:
    args_in = [
      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
      for i in range(len(input_shapes))
    ]
    args_in = '(\n ' + "\n ".join(args_in) + '\n )'
  else:
    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
  if len(output_shapes) > 1:
    args_out = [
      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
      for i in range(len(output_shapes))
    ]
    args_out = '(\n ' + "\n ".join(args_out) + '\n )'
    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
  else:
    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
  # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  args_out = {args_out}
  args_in = {args_in}
  func_to_call(args_out, args_in)
  '''.format(args_in=args_in,
             args_out=args_out)

  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register
  xla_c_rule = cfunc(sig)(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call
  return custom_call(
    call_target_name=target_name,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results
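To make the new calling convention concrete, this self-contained sketch rebuilds the source the rule above generates for a hypothetical op with two inputs and two outputs; carray, the shape and dtype tuples, and func_to_call are placeholders that live in code_scope when the real rule runs.

# Illustration only: reproduce the generated wrapper for a hypothetical 2-input / 2-output op.
args_in = '(\n ' + "\n ".join(
  f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),' for i in range(2)
) + '\n )'
args_out = '(\n ' + "\n ".join(
  f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),' for i in range(2)
) + '\n )'
print('''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  args_out = {args_out}
  args_in = {args_in}
  func_to_call(args_out, args_in)
'''.format(args_in=args_in, args_out=args_out))

Whatever the number of results, the user kernel is now invoked as func_to_call(outs, ins): outs and ins are single carray views when there is one buffer and tuples of views otherwise, matching the (outs, ins) signature that the CustomOpByNumba tests below rely on.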
@@ -0,0 +1,48 @@
import jax.core
import pytest
from jax.core import ShapedArray

import brainpy.math as bm
from brainpy._src.dependency_check import import_numba

numba = import_numba(error_if_not_found=False)
if numba is None:
  pytest.skip('no numba', allow_module_level=True)

bm.set_platform('cpu')


def eval_shape(a):
  b = ShapedArray(a.shape, dtype=a.dtype)
  return b

@numba.njit(parallel=True)
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

def test_CustomOpByNumba_single_result():
  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
  print(op(bm.zeros(10)))

def eval_shape2(a, b):
  c = ShapedArray(a.shape, dtype=a.dtype)
  d = ShapedArray(b.shape, dtype=b.dtype)
  return c, d

def con_compute2(outs, ins):
  c = outs[0]  # take out all the outputs
  d = outs[1]
  a = ins[0]  # take out all the inputs
  b = ins[1]
  # c, d = outs
  # a, b = ins
  c[:] = a + 1
  d[:] = b * 2

def test_CustomOpByNumba_multiple_results():
  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
  print(op2(bm.zeros(10), bm.ones(10)))

test_CustomOpByNumba_multiple_results()