Commit e55a9b0: Update

Routhleck committed May 14, 2024
1 parent 7898658

Showing 3 changed files with 131 additions and 62 deletions.
68 changes: 6 additions & 62 deletions brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -6,16 +6,15 @@

import jax
from jax.interpreters import xla, batching, ad, mlir
from jax.lib import xla_client

from jax.tree_util import tree_map
from jaxlib.hlo_helpers import custom_call

from brainpy._src.dependency_check import import_numba
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math.op_register.utils import _shape_to_layout

from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
+from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule

numba = import_numba(error_if_not_found=False)
if numba is not None:
@@ -224,62 +223,6 @@ def abs_eval_rule(*input_shapes, **info):
  return prim


def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
  # output information
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function
  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
             for i in range(len(input_shapes))]
  if len(output_shapes) > 1:
    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
                for i in range(len(output_shapes))]
    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
  else:
    args_out = ['out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
    {args_out}
    {args_in}
    func_to_call({args_call})
'''.format(args_out="\n    ".join(args_out),
           args_in="\n    ".join(args_in),
           args_call=", ".join(args_call))

  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register
  xla_c_rule = cfunc(sig)(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call
  return custom_call(
    call_target_name=target_name,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results


def register_op_with_numba_mlir(
    op_name: str,
    cpu_func: Callable,
Expand Down Expand Up @@ -329,8 +272,9 @@ def abs_eval_rule(*input_shapes, **info):
prim.def_abstract_eval(abs_eval_rule)
prim.def_impl(partial(xla.apply_primitive, prim))

-  def cpu_translation_rule(ctx, *ins, **kwargs):
-    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)
+  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
+                                 cpu_func,
+                                 True)

  mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
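For context, the lowering registered here ultimately invokes the user function as `func_to_call(args_out, args_in)`, so a compatible Numba kernel takes the output buffer(s) first and the input buffer(s) second, writing its results in place. A minimal sketch of such a kernel (hypothetical name; with a single input and single output, both arguments arrive as bare carray views rather than tuples):

import numba

@numba.njit
def double_it(outs, ins):
  # write the result into the preallocated output buffer
  outs[:] = ins * 2.0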

76 changes: 76 additions & 0 deletions brainpy/_src/math/op_register/numba_approach/cpu_translation.py
@@ -5,8 +5,11 @@
from jax import dtypes, numpy as jnp
from jax.core import ShapedArray
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call
from jax.interpreters import mlir

from brainpy._src.dependency_check import import_numba
from brainpy._src.math.op_register.utils import _shape_to_layout

numba = import_numba(error_if_not_found=False)
ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -19,6 +22,7 @@
__all__ = [
  '_cpu_translation',
  'compile_cpu_signature_with_numba',
  '_numba_mlir_cpu_translation_rule',
]

if numba is not None:
@@ -150,3 +154,75 @@ def compile_cpu_signature_with_numba
                    if multiple_results else
                    output_layouts[0])
  return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
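Both the signature path above and the new MLIR rule below use `_shape_to_layout` to describe each buffer to XLA. As a hedged sketch (not the actual `brainpy._src.math.op_register.utils` source), it presumably returns the default row-major, minor-to-major layout descriptor for a shape:

def _shape_to_layout(shape):
  # e.g. shape (4, 3, 2) -> layout (2, 1, 0), XLA's row-major default
  return tuple(range(len(shape) - 1, -1, -1))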


def _numba_mlir_cpu_translation_rule(
    cpu_func,
    debug,
    ctx,
    *ins,
    **kwargs
):
  # output information
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function
  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  if len(input_shapes) > 1:
    args_in = [
      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
      for i in range(len(input_shapes))
    ]
    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
  else:
    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
  if len(output_shapes) > 1:
    args_out = [
      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
      for i in range(len(output_shapes))
    ]
    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
  else:
    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  args_out = {args_out}
  args_in = {args_in}
  func_to_call(args_out, args_in)
'''.format(args_in=args_in,
           args_out=args_out)

  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register
  xla_c_rule = cfunc(sig)(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call
  return custom_call(
    call_target_name=target_name,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results
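To make the string templating above concrete: for an op with two outputs and one input, the generated custom-call target would read roughly as follows (names such as `output_shapes` and `carray` resolve through `code_scope`; whitespace approximate):

def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  args_out = (
    carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
    carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
  )
  args_in = carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])
  func_to_call(args_out, args_in)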
49 changes: 49 additions & 0 deletions (new test file)
@@ -0,0 +1,49 @@
import pytest
from jax.core import ShapedArray

import brainpy.math as bm
from brainpy._src.dependency_check import import_numba

numba = import_numba(error_if_not_found=False)
if numba is None:
  pytest.skip('no numba', allow_module_level=True)

bm.set_platform('cpu')


def eval_shape(a):
  b = ShapedArray(a.shape, dtype=a.dtype)
  return b


@numba.njit(parallel=True)
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1


def test_CustomOpByNumba_single_result():
  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
  print(op(bm.zeros(10)))


def eval_shape2(a, b):
  c = ShapedArray(a.shape, dtype=a.dtype)
  d = ShapedArray(b.shape, dtype=b.dtype)
  return c, d


@numba.njit(parallel=True)
def con_compute2(outs, ins):
  c = outs[0]  # take out all the outputs
  d = outs[1]
  a = ins[0]  # take out all the inputs
  b = ins[1]
  c[:] = a + 1
  d[:] = b * 2


def test_CustomOpByNumba_multiple_results():
  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
  print(op2(bm.zeros(10), bm.ones(10)))


if __name__ == '__main__':
  test_CustomOpByNumba_multiple_results()
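For reference, assuming the lowering round-trips correctly, the kernels above imply output along these lines (formatting approximate):

# test_CustomOpByNumba_single_result    -> [1. 1. 1. ... 1.]          (zeros(10) + 1)
# test_CustomOpByNumba_multiple_results -> ([1. ... 1.], [2. ... 2.]) (a + 1, b * 2)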
