[math] Update CustomOpByNumba to support JAX version >= 0.4.24 (#669)
* [math] Update `CustomOpByNumba` to support JAX version >= 0.4.24

* Update dependency_check.py

* Update dependency_check.py

* Update requirements-dev.txt
Routhleck authored May 14, 2024
1 parent 78036e9 commit 4d4eea5
Showing 4 changed files with 192 additions and 25 deletions.
10 changes: 7 additions & 3 deletions brainpy/_src/dependency_check.py
@@ -57,9 +57,13 @@ def import_taichi(error_if_not_found=True):

  if taichi is None:
    return None
  if taichi.__version__ != _minimal_taichi_version:
    raise RuntimeError(taichi_install_info)
  return taichi
  taichi_version = taichi.__version__[0] * 10000 + taichi.__version__[1] * 100 + taichi.__version__[2]
  minimal_taichi_version = _minimal_taichi_version[0] * 10000 + _minimal_taichi_version[1] * 100 + \
                           _minimal_taichi_version[2]
  if taichi_version >= minimal_taichi_version:
    return taichi
  else:
    raise ModuleNotFoundError(taichi_install_info)


def raise_taichi_not_found(*args, **kwargs):
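For reference, a minimal standalone sketch (not part of the commit) of the version check introduced above: a (major, minor, patch) tuple is packed into a single integer so that any Taichi release at or above the minimum is accepted.

def version_to_int(version):
  # Mirrors the packing used in import_taichi above; illustrative only.
  major, minor, patch = version
  return major * 10000 + minor * 100 + patch

minimum = (1, 7, 0)  # hypothetical minimum; the real value lives in _minimal_taichi_version
assert version_to_int((1, 7, 2)) >= version_to_int(minimum)      # new enough: taichi is returned
assert not version_to_int((1, 6, 9)) >= version_to_int(minimum)  # too old: ModuleNotFoundError is raised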
2 changes: 1 addition & 1 deletion brainpy/_src/math/op_register/__init__.py
@@ -1,5 +1,5 @@
from .numba_approach import (CustomOpByNumba,
                             register_op_with_numba,
                             register_op_with_numba_xla,
                             compile_cpu_signature_with_numba)
from .base import XLACustomOp
from .utils import register_general_batching
187 changes: 166 additions & 21 deletions brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -1,29 +1,41 @@
# -*- coding: utf-8 -*-

import ctypes
from functools import partial
from typing import Callable
from typing import Union, Sequence

import jax
from jax.interpreters import xla, batching, ad
from jax.interpreters import xla, batching, ad, mlir
from jax.lib import xla_client
from jax.tree_util import tree_map
from jaxlib.hlo_helpers import custom_call

from brainpy._src.dependency_check import import_numba
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject
from brainpy._src.math.op_register.utils import _shape_to_layout
from brainpy.errors import PackageMissingError
from .cpu_translation import _cpu_translation, compile_cpu_signature_with_numba

numba = import_numba(error_if_not_found=False)

if numba is not None:
  from numba import types, carray, cfunc

__all__ = [
  'CustomOpByNumba',
  'register_op_with_numba',
  'register_op_with_numba_xla',
  'compile_cpu_signature_with_numba',
]


def _transform_to_shapedarray(a):
  return jax.core.ShapedArray(a.shape, a.dtype)


def convert_shapedarray_to_shapedtypestruct(shaped_array):
  return jax.ShapeDtypeStruct(shape=shaped_array.shape, dtype=shaped_array.dtype)


class CustomOpByNumba(BrainPyObject):
  """Creating a XLA custom call operator with Numba JIT on CPU backend.
@@ -61,20 +73,35 @@ def __init__(
    # abstract evaluation function
    if eval_shape is None:
      raise ValueError('Must provide "eval_shape" for abstract evaluation.')
    self.eval_shape = eval_shape

    # cpu function
    cpu_func = con_compute

    # register OP
    self.op = register_op_with_numba(
      self.name,
      cpu_func=cpu_func,
      out_shapes=eval_shape,
      batching_translation=batching_translation,
      jvp_translation=jvp_translation,
      transpose_translation=transpose_translation,
      multiple_results=multiple_results,
    )
    if jax.__version__ > '0.4.23':
      self.op_method = 'mlir'
      self.op = register_op_with_numba_mlir(
        self.name,
        cpu_func=cpu_func,
        out_shapes=eval_shape,
        gpu_func_translation=None,
        batching_translation=batching_translation,
        jvp_translation=jvp_translation,
        transpose_translation=transpose_translation,
        multiple_results=multiple_results,
      )
    else:
      self.op_method = 'xla'
      self.op = register_op_with_numba_xla(
        self.name,
        cpu_func=cpu_func,
        out_shapes=eval_shape,
        batching_translation=batching_translation,
        jvp_translation=jvp_translation,
        transpose_translation=transpose_translation,
        multiple_results=multiple_results,
      )

  def __call__(self, *args, **kwargs):
    args = tree_map(lambda a: a.value if isinstance(a, Array) else a,
@@ -85,7 +112,7 @@ def __call__(self, *args, **kwargs):
    return res
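A minimal usage sketch of the class above, mirroring the test added at the end of this commit: eval_shape supplies the abstract output, con_compute fills the pre-allocated output buffer in place, and the same call works whether the op was registered through the MLIR or the XLA path.

import brainpy.math as bm
from jax.core import ShapedArray
from brainpy._src.dependency_check import import_numba

numba = import_numba(error_if_not_found=False)

def eval_shape(a):
  # abstract evaluation: output has the same shape and dtype as the input
  return ShapedArray(a.shape, dtype=a.dtype)

@numba.njit
def con_compute(outs, ins):
  outs[:] = ins + 1  # write the result into the output buffer provided by XLA

op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
assert bm.allclose(op(bm.zeros(10)), bm.ones(10))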


def register_op_with_numba(
def register_op_with_numba_xla(
    op_name: str,
    cpu_func: Callable,
    out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
@@ -132,13 +159,6 @@ def register_op_with_numba(
    A JAX Primitive object.
  """

  if jax.__version__ > '0.4.23':
    raise RuntimeError(f'{CustomOpByNumba.__name__} and {register_op_with_numba.__name__} are '
                       f'only supported in JAX version <= 0.4.23. \n'
                       f'However, you can use brainpy.math.XLACustomOp to create a custom op with numba syntax. '
                       f'For more information, please refer to the documentation: '
                       f'https://brainpy.readthedocs.io/en/latest/tutorial_advanced/operator_custom_with_taichi.html.')

  if numba is None:
    raise PackageMissingError.by_purpose('numba', 'custom op with numba')

@@ -202,3 +222,128 @@ def abs_eval_rule(*input_shapes, **info):
    ad.primitive_transposes[prim] = transpose_translation

  return prim


def _numba_mlir_cpu_translation_rule(kernel, debug: bool, ctx, *ins, **kwargs):
  # output information
  outs = ctx.avals_out
  output_shapes = tuple([out.shape for out in outs])
  output_dtypes = tuple([out.dtype for out in outs])
  output_layouts = tuple([_shape_to_layout(out.shape) for out in outs])
  result_types = [mlir.aval_to_ir_type(out) for out in outs]

  # input information
  avals_in = ctx.avals_in
  input_layouts = [_shape_to_layout(a.shape) for a in avals_in]
  input_dtypes = tuple(inp.dtype for inp in avals_in)
  input_shapes = tuple(inp.shape for inp in avals_in)

  # compiling function
  code_scope = dict(func_to_call=kernel, input_shapes=input_shapes, input_dtypes=input_dtypes,
                    output_shapes=output_shapes, output_dtypes=output_dtypes, carray=carray)
  args_in = [f'in{i} = carray(input_ptrs[{i}], input_shapes[{i}], dtype=input_dtypes[{i}])'
             for i in range(len(input_shapes))]
  if len(output_shapes) > 1:
    args_out = [f'out{i} = carray(output_ptrs[{i}], output_shapes[{i}], dtype=output_dtypes[{i}])'
                for i in range(len(output_shapes))]
    sig = types.void(types.CPointer(types.voidptr), types.CPointer(types.voidptr))
  else:
    args_out = [f'out0 = carray(output_ptrs, output_shapes[0], dtype=output_dtypes[0])']
    sig = types.void(types.voidptr, types.CPointer(types.voidptr))
  args_call = [f'out{i}' for i in range(len(output_shapes))] + [f'in{i}' for i in range(len(input_shapes))]
  code_string = '''
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  {args_out}
  {args_in}
  func_to_call({args_call})
'''.format(args_out="\n  ".join(args_out), args_in="\n  ".join(args_in), args_call=", ".join(args_call))

  if debug:
    print(code_string)
  exec(compile(code_string.strip(), '', 'exec'), code_scope)
  new_f = code_scope['numba_cpu_custom_call_target']

  # register
  xla_c_rule = cfunc(sig)(new_f)
  target_name = f'numba_custom_call_{str(xla_c_rule.address)}'
  capsule = ctypes.pythonapi.PyCapsule_New(xla_c_rule.address, b"xla._CUSTOM_CALL_TARGET", None)
  xla_client.register_custom_call_target(target_name, capsule, "cpu")

  # call
  return custom_call(
    call_target_name=target_name,
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    has_side_effect=False,
  ).results
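For intuition, a hand-written, hard-coded equivalent (an illustration, not code from this commit) of the wrapper that the generated code string produces for one input and one output; in the lowering rule the shapes and dtypes come from ctx.avals_in / ctx.avals_out and func_to_call is the user's jitted cpu_func.

import numpy as np
from numba import carray, cfunc, njit, types

@njit
def func_to_call(out0, in0):  # stands in for the user's jitted kernel
  out0[:] = in0 + 1.0

# single-output signature, as chosen in the `else` branch above: void(void*, void**)
sig = types.void(types.voidptr, types.CPointer(types.voidptr))

@cfunc(sig)
def numba_cpu_custom_call_target(output_ptrs, input_ptrs):
  out0 = carray(output_ptrs, (10,), dtype=np.float64)   # view over the XLA output buffer
  in0 = carray(input_ptrs[0], (10,), dtype=np.float64)  # view over the first XLA input buffer
  func_to_call(out0, in0)

# numba_cpu_custom_call_target.address is what gets wrapped in a PyCapsule and handed to
# xla_client.register_custom_call_target in the lowering rule above.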


def register_op_with_numba_mlir(
    op_name: str,
    cpu_func: Callable,
    out_shapes: Union[Callable, jax.core.ShapedArray, Sequence[jax.core.ShapedArray]],
    gpu_func_translation: Callable = None,
    batching_translation: Callable = None,
    jvp_translation: Callable = None,
    transpose_translation: Callable = None,
    multiple_results: bool = False,
):
  if numba is None:
    raise PackageMissingError.by_purpose('numba', 'custom op with numba')

  if out_shapes is None:
    raise RuntimeError('out_shapes cannot be None. It can be a `ShapedArray` or '
                       'a sequence of `ShapedArray`. If it is a function, it takes as input the argument '
                       'shapes and dtypes and should return correct output shapes of `ShapedArray`.')

  prim = jax.core.Primitive(op_name)
  prim.multiple_results = multiple_results

  from numba.core.dispatcher import Dispatcher
  if not isinstance(cpu_func, Dispatcher):
    cpu_func = numba.jit(fastmath=True, nopython=True)(cpu_func)

  def abs_eval_rule(*input_shapes, **info):
    if callable(out_shapes):
      shapes = out_shapes(*input_shapes, **info)
    else:
      shapes = out_shapes

    if isinstance(shapes, jax.core.ShapedArray):
      assert not multiple_results, "multiple_results is True, while the abstract evaluation returns only one data."
    elif isinstance(shapes, (tuple, list)):
      assert multiple_results, "multiple_results is False, while the abstract evaluation returns multiple data."
      for elem in shapes:
        if not isinstance(elem, jax.core.ShapedArray):
          raise ValueError(f'Elements in "out_shapes" must be instances of '
                           f'jax.abstract_arrays.ShapedArray, but we got '
                           f'{type(elem)}: {elem}')
    else:
      raise ValueError(f'Unknown type {type(shapes)}, only '
                       f'supports function, ShapedArray or '
                       f'list/tuple of ShapedArray.')
    return shapes

  prim.def_abstract_eval(abs_eval_rule)
  prim.def_impl(partial(xla.apply_primitive, prim))

  def cpu_translation_rule(ctx, *ins, **kwargs):
    return _numba_mlir_cpu_translation_rule(cpu_func, False, ctx, *ins, **kwargs)

  mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')

  if gpu_func_translation is not None:
    mlir.register_lowering(prim, gpu_func_translation, platform='gpu')

  if batching_translation is not None:
    jax.interpreters.batching.primitive_batchers[prim] = batching_translation

  if jvp_translation is not None:
    jax.interpreters.ad.primitive_jvps[prim] = jvp_translation

  if transpose_translation is not None:
    jax.interpreters.ad.primitive_transposes[prim] = transpose_translation

  return prim
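A hedged sketch of calling the new helper directly (the commit's test goes through CustomOpByNumba instead); the op name, kernel, and private import path below are assumptions for illustration. The returned jax.core.Primitive is invoked with bind, and out_shapes may be a ShapedArray, a sequence of them, or a callable, matching what abs_eval_rule above accepts.

import jax.numpy as jnp
from jax.core import ShapedArray
from brainpy._src.math.op_register.numba_approach import register_op_with_numba_mlir

def double_kernel(out0, in0):
  out0[:] = in0 * 2.0  # plain function; register_op_with_numba_mlir wraps it with numba.jit

prim = register_op_with_numba_mlir(
  'double_op',                                               # illustrative op name
  cpu_func=double_kernel,
  out_shapes=lambda a, **kw: ShapedArray(a.shape, a.dtype),  # callable form of out_shapes
  multiple_results=False,
)

y = prim.bind(jnp.ones(4))  # runs the Numba kernel through the MLIR custom-call lowering
print(y)                    # expected: [2. 2. 2. 2.]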
18 changes: 18 additions & 0 deletions brainpy/_src/math/op_register/tests/test_numba_based.py
@@ -1,5 +1,6 @@
import jax.core
import pytest
from jax.core import ShapedArray

import brainpy.math as bm
from brainpy._src.dependency_check import import_numba
@@ -35,3 +36,20 @@ def test_event_ELL():
  call(1000)
  call(100)
  bm.clear_buffer_memory()


# CustomOpByNumba Test

def eval_shape(a):
  b = ShapedArray(a.shape, dtype=a.dtype)
  return b

@numba.njit(parallel=True)
def con_compute(outs, ins):
  b = outs
  a = ins
  b[:] = a + 1

def test_CustomOpByNumba():
  op = bm.CustomOpByNumba(eval_shape, con_compute, multiple_results=False)
  print(op(bm.zeros(10)))
  assert bm.allclose(op(bm.zeros(10)), bm.ones(10))
