[math] Fix CustomOpByNumba on multiple_results=True (#671)
* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24

* Update dependency_check.py

* Update dependency_check.py

* Update requirements-dev.txt

* Update

* Update operator_custom_with_numba.ipynb

* Update __init__.py

* Update dependency_check.py

* Update __init__.py

* Fix

* Update docs

* Update operator_custom_with_taichi.ipynb
Routhleck authored May 15, 2024
1 parent 4d4eea5 commit e3a854a
Showing 6 changed files with 241 additions and 79 deletions.
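In short: this commit moves _numba_mlir_cpu_translation_rule out of numba_approach/__init__.py into cpu_translation.py, and changes the Numba kernel it generates so that all outputs and all inputs are packed into the two tuple arguments (outs, ins) that CustomOpByNumba kernels expect, instead of being splatted as separate positional arguments. A test covering both multiple_results=False and multiple_results=True is added.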
6 changes: 3 additions & 3 deletions brainpy/_src/dependency_check.py
@@ -25,9 +25,9 @@
 brainpylib_cpu_ops = None
 brainpylib_gpu_ops = None
 
-taichi_install_info = (f'We need taichi=={_minimal_taichi_version}. '
-                       f'Currently you can install taichi=={_minimal_taichi_version} through:\n\n'
-                       '> pip install taichi==1.7.0')
+taichi_install_info = (f'We need taichi>={_minimal_taichi_version}. '
+                       f'Currently you can install taichi=={_minimal_taichi_version} by pip . \n'
+                       '> pip install taichi -U')
 numba_install_info = ('We need numba. Please install numba by pip . \n'
                       '> pip install numba')
 cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
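Note on the hunk above: the taichi requirement is relaxed from an exact pin (==) to a minimum version (>=), so the suggested install command becomes 'pip install taichi -U' rather than pinning taichi==1.7.0.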
68 changes: 7 additions & 61 deletions brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -1,21 +1,22 @@
 # -*- coding: utf-8 -*-
 import ctypes
 import ctypes
 from functools import partial
 from typing import Callable
 from typing import Union, Sequence
 
 import jax
 from jax.interpreters import xla, batching, ad, mlir
 from jax.lib import xla_client
 
 from jax.tree_util import tree_map
 from jaxlib.hlo_helpers import custom_call
 
 from brainpy._src.dependency_check import import_numba
 from brainpy._src.math.ndarray import Array
 from brainpy._src.math.object_transform.base import BrainPyObject
 from brainpy._src.math.op_register.utils import _shape_to_layout
 
 from brainpy.errors import PackageMissingError
-from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba
+from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba, _numba_mlir_cpu_translation_rule
 
 numba = import_numba(error_if_not_found=False)
 if numba is not None:
@@ -224,62 +225,6 @@ def abs_eval_rule(*input_shapes, **info):
   return prim
 
 
-def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
-  # output information
-  outs = ctx.avals_out
-  output_shapes = tuple([out.shape for out in outs])
-  output_dtypes = tuple([out.dtype for out in outs])
-  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
-  result_types = [mlir.aval_to_ir_type(out) for out in outs]
-
-  # input information
-  avals_in = ctx.avals_in
-  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
-  input_dtypes = tuple(inp.dtype for inp in avals_in)
-  input_shapes = tuple(inp.shape for inp in avals_in)
-
-  # compiling function
-  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
-                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
-  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
-             for i in range(len(input_shapes))]
-  if len(output_shapes) > 1:
-    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
-                for i in range(len(output_shapes))]
-    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
-  else:
-    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
-    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
-  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
-  code_string = '''
-def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
-    {args_out}
-    {args_in}
-    func_to_call({args_call})
-    '''.format(args_out="\n    ".join(args_out), args_in="\n    ".join(args_in), args_call=", ".join(args_call))
-
-  if debug:
-    print(code_string)
-  exec(compile(code_string.strip(), '', 'exec'), code_scope)
-  new_f = code_scope['numba_cpu_custom_call_target']
-
-  # register
-  xla_c_rule = cfunc(sig)(new_f)
-  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
-  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
-  xla_client.register_custom_call_target(target_name, capsule, "cpu")
-
-  # call
-  return custom_call(
-    call_target_name=target_name,
-    operands=ins,
-    operand_layouts=list(input_layouts),
-    result_layouts=list(output_layouts),
-    result_types=list(result_types),
-    has_side_effect=False,
-  ).results
-
-
 def register_op_with_numba_mlir(
     op_name: str,
     cpu_func: Callable,
@@ -329,8 +274,9 @@ def abs_eval_rule(*input_shapes, **info):
   prim.def_abstract_eval(abs_eval_rule)
   prim.def_impl(partial(xla.apply_primitive, prim))
 
-  def cpu_translation_rule(ctx, *ins, **kwargs):
-    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)
+  cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
+                                 cpu_func,
+                                 False)
 
   mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
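The lowering rule is now built with functools.partial, binding cpu_func and debug=False onto the shared _numba_mlir_cpu_translation_rule helper that moves into cpu_translation.py (next file), instead of defining a per-op closure.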
76 changes: 76 additions & 0 deletions brainpy/_src/math/op_register/numba_approach/cpu_translation.py
@@ -5,8 +5,11 @@
 from jax import dtypes, numpy as jnp
 from jax.core import ShapedArray
 from jax.lib import xla_client
+from jaxlib.hlo_helpers import custom_call
+from jax.interpreters import mlir
 
 from brainpy._src.dependency_check import import_numba
+from brainpy._src.math.op_register.utils import _shape_to_layout
 
 numba = import_numba(error_if_not_found=False)
 ctypes.pythonapi.PyCapsule_New.argtypes = [
@@ -19,6 +22,7 @@
 __all__ = [
   '_cpu_translation',
   'compile_cpu_signature_with_numba',
+  '_numba_mlir_cpu_translation_rule',
 ]
 
 if numba is not None:
@@ -150,3 +154,75 @@ def compile_cpu_signature_with_numba(
                    if multiple_results else
                    output_layouts[0])
   return target_name, tuple(inputs) + tuple(info_inputs), input_layouts, output_layouts
+
+
+def _numba_mlir_cpu_translation_rule(
+    cpu_func,
+    debug,
+    ctx,
+    *ins,
+    **kwargs
+):
+  # output information
+  outs = ctx.avals_out
+  output_shapes = tuple([out.shape for out in outs])
+  output_dtypes = tuple([out.dtype for out in outs])
+  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
+  result_types = [mlir.aval_to_ir_type(out) for out in outs]
+
+  # input information
+  avals_in = ctx.avals_in
+  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
+  input_dtypes = tuple(inp.dtype for inp in avals_in)
+  input_shapes = tuple(inp.shape for inp in avals_in)
+
+  # compiling function
+  code_scope = dict(func_to_call=cpu_func, input_shapes=input_shapes, input_dtypes=input_dtypes,
+                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
+  if len(input_shapes) > 1:
+    args_in = [
+      f'carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}]),'
+      for i in range(len(input_shapes))
+    ]
+    args_in = '(\n    ' + "\n    ".join(args_in) + '\n  )'
+  else:
+    args_in = 'carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0])'
+  if len(output_shapes) > 1:
+    args_out = [
+      f'carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}]),'
+      for i in range(len(output_shapes))
+    ]
+    args_out = '(\n    ' + "\n    ".join(args_out) + '\n  )'
+    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
+  else:
+    args_out = 'carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])'
+    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
+  # args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
+  code_string = '''
+def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
+    args_out = {args_out}
+    args_in = {args_in}
+    func_to_call(args_out, args_in)
+    '''.format(args_in=args_in,
+               args_out=args_out)
+
+  if debug:
+    print(code_string)
+  exec(compile(code_string.strip(), '', 'exec'), code_scope)
+  new_f = code_scope['numba_cpu_custom_call_target']
+
+  # register
+  xla_c_rule = cfunc(sig)(new_f)
+  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
+  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
+  xla_client.register_custom_call_target(target_name, capsule, "cpu")
+
+  # call
+  return custom_call(
+    call_target_name=target_name,
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    has_side_effect=False,
+  ).results
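To make the codegen concrete: for a primitive with two inputs and two outputs, the format strings above emit roughly the following trampoline (a sketch with whitespace tidied; carray, the shape/dtype tuples, and func_to_call are resolved through code_scope, and the function is then compiled by cfunc(sig) with sig = void(CPointer(voidptr), CPointer(voidptr))):

def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
    args_out = (
        carray(output_ptrs[0], output_shapes[0], dtype=output_dtypes[0]),
        carray(output_ptrs[1], output_shapes[1], dtype=output_dtypes[1]),
    )
    args_in = (
        carray(input_ptrs[0], input_shapes[0], dtype=input_dtypes[0]),
        carray(input_ptrs[1], input_shapes[1], dtype=input_dtypes[1]),
    )
    func_to_call(args_out, args_in)

This is the heart of the multiple_results=True fix: the rule removed from __init__.py called func_to_call(out0, out1, in0, in1), while CustomOpByNumba kernels take exactly two arguments, (outs, ins).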
48 changes: 48 additions & 0 deletions (new test file)
@@ -0,0 +1,48 @@
+import jax.core
+import pytest
+from jax.core import ShapedArray
+
+import brainpy.math as bm
+from brainpy._src.dependency_check import import_numba
+
+numba = import_numba(error_if_not_found=False)
+if numba is None:
+  pytest.skip('no numba', allow_module_level=True)
+
+bm.set_platform('cpu')
+
+
+def eval_shape(a):
+  b = ShapedArray(a.shape, dtype=a.dtype)
+  return b
+
+@numba.njit(parallel=True)
+def con_compute(outs, ins):
+  b = outs
+  a = ins
+  b[:] = a + 1
+
+def test_CustomOpByNumba_single_result():
+  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
+  print(op(bm.zeros(10)))
+
+def eval_shape2(a, b):
+  c = ShapedArray(a.shape, dtype=a.dtype)
+  d = ShapedArray(b.shape, dtype=b.dtype)
+  return c, d
+
+def con_compute2(outs, ins):
+  c = outs[0]  # take out all the outputs
+  d = outs[1]
+  a = ins[0]  # take out all the inputs
+  b = ins[1]
+  # c, d = outs
+  # a, b = ins
+  c[:] = a + 1
+  d[:] = b * 2
+
+def test_CustomOpByNumba_multiple_results():
+  op2 = bm.CustomOpByNumba(eval_shape2, con_compute2, multiple_results=True)
+  print(op2(bm.zeros(10), bm.ones(10)))
+
+test_CustomOpByNumba_multiple_results()
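The two kernels above exercise both calling conventions: with multiple_results=False the kernel receives the bare output array and input array (con_compute), while with multiple_results=True it receives indexable tuples of outputs and inputs (con_compute2), matching the args_out/args_in tuples built by the new translation rule.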