[math] Add new customized operators with cupy (#653)
* Update

* Implement cupy-based customized operators (still needs testing)

* Fix bugs

* Update base.py

* Update dependency_check.py

* Format code

* Implement customized op with cupy `JIT Kernel`

* Update docs

* small update

---------

Co-authored-by: Chaoming Wang <[email protected]>
Routhleck and chaoming0625 authored Mar 20, 2024
1 parent 23b5ab9 commit 3866203
Showing 7 changed files with 618 additions and 55 deletions.
50 changes: 50 additions & 0 deletions brainpy/_src/dependency_check.py
Expand Up @@ -8,6 +8,9 @@
  'raise_taichi_not_found',
  'import_numba',
  'raise_numba_not_found',
  'import_cupy',
  'import_cupy_jit',
  'raise_cupy_not_found',
  'import_brainpylib_cpu_ops',
  'import_brainpylib_gpu_ops',
]
@@ -17,6 +20,8 @@

numba = None
taichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

@@ -25,6 +30,9 @@
                       '> pip install taichi==1.7.0')
numba_install_info = ('We need numba. Please install numba via pip:\n'
                      '> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy via pip:\n'
                     'For CUDA v11.2 ~ 11.8: > pip install cupy-cuda11x\n'
                     'For CUDA v12.x: > pip install cupy-cuda12x\n')
os.environ["TI_LOG_LEVEL"] = "error"


@@ -81,6 +89,48 @@ def raise_numba_not_found():
  raise ModuleNotFoundError(numba_install_info)


def import_cupy(error_if_not_found=True):
  """
  Internal API to import cupy.

  If cupy is not found, raise a ModuleNotFoundError when `error_if_not_found`
  is True; otherwise return None.
  """
  global cupy
  if cupy is None:
    try:
      import cupy
    except ModuleNotFoundError:
      if error_if_not_found:
        raise_cupy_not_found()
      else:
        return None
  return cupy


def import_cupy_jit(error_if_not_found=True):
  """
  Internal API to import cupyx.jit.

  If cupyx.jit is not found, raise a ModuleNotFoundError when
  `error_if_not_found` is True; otherwise return None.
  """
  global cupy_jit
  if cupy_jit is None:
    try:
      from cupyx import jit as cupy_jit
    except ModuleNotFoundError:
      if error_if_not_found:
        raise_cupy_not_found()
      else:
        return None
  return cupy_jit


def raise_cupy_not_found():
  raise ModuleNotFoundError(cupy_install_info)


def is_brainpylib_gpu_installed():
  return brainpylib_gpu_ops is not None

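As a usage note: a minimal sketch of how these lazy importers are consumed at a call site (the pattern mirrors base.py below; the fallback branch is illustrative, not from this commit):

from brainpy._src.dependency_check import import_cupy, import_cupy_jit

cp = import_cupy(error_if_not_found=False)         # returns None if cupy is absent
cp_jit = import_cupy_jit(error_if_not_found=False)

if cp is None:
  # cupy is optional here: fall back to a CPU path, or call
  # raise_cupy_not_found() at the point where a GPU kernel is required.
  ...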
53 changes: 20 additions & 33 deletions brainpy/_src/math/op_register/base.py
@@ -1,26 +1,31 @@
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional
from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

import jax
import numpy as np
from jax.interpreters import xla, batching, ad, mlir

from brainpy._src.dependency_check import import_numba
from brainpy._src.dependency_check import import_numba, import_cupy_jit
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject

if jax.__version__ >= '0.4.16':
  from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
  from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

numba = import_numba(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)

__all__ = [
  'XLACustomOp',
@@ -41,34 +46,10 @@ def dtype(self) -> np.dtype:
class XLACustomOp(BrainPyObject):
  """Creating an XLA custom call operator.
  >>> import numba as nb
  >>> import taichi as ti
  >>> import numpy as np
  >>> import jax
  >>>
  >>> @nb.njit
  >>> def numba_cpu_fun(a, b, out_a, out_b):
  >>>   out_a[:] = a
  >>>   out_b[:] = b
  >>>
  >>> @ti.kernel
  >>> def taichi_gpu_fun(a, b, out_a, out_b):
  >>>   for i in range(a.size):
  >>>     out_a[i] = a[i]
  >>>   for i in range(b.size):
  >>>     out_b[i] = b[i]
  >>>
  >>> # option 1
  >>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun)
  >>> a2, b2 = prim(np.random.random(1000), np.random.random(1000),
  >>>               outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
  >>>                     jax.ShapeDtypeStruct(1000, dtype=np.float32)])
  >>>
  >>> # option 2
  >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
  >>>                     outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
  >>>                                                  jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
  >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))
  For more information, please refer to the tutorials below:
  Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
  Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
  CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html
  Args:
    cpu_kernel: Callable. The function defines the computation on CPU backend.
@@ -83,7 +64,7 @@ class XLACustomOp(BrainPyObject):
  def __init__(
      self,
      cpu_kernel: Callable = None,
      gpu_kernel: Callable = None,
      gpu_kernel: Union[Callable, str] = None,
      batching_translation: Callable = None,
      jvp_translation: Callable = None,
      transpose_translation: Callable = None,
@@ -125,11 +106,17 @@ def __init__(
    gpu_checked = False
    if gpu_kernel is None:
      gpu_checked = True
    if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
    elif hasattr(gpu_kernel, 'kernel'):  # cupy RawModule
      register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_mode'):  # cupy JIT Kernel
      register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
      register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    if not gpu_checked:
      raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')
      raise ValueError(f'"gpu_kernel" must be a taichi kernel, a cupy RawModule kernel, or a cupy JIT kernel. But we got {gpu_kernel}')

    # batching rule
    if batching_translation is None:
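For reference, a hedged sketch of the two kinds of cupy kernel objects the new gpu_kernel dispatch accepts. The kernel source and the names double_it / double_it_jit are illustrative, not from this commit; the attribute checks noted in the comments are the ones performed above:

import cupy as cp
from cupyx import jit

# 1) A RawModule kernel: RawModule(...).get_function(...) returns a RawKernel,
#    which exposes a `.kernel` attribute, matching the first cupy branch
#    (`hasattr(gpu_kernel, 'kernel')`).
source = r'''
extern "C" __global__
void double_it(const float* x, float* y, int n) {
  int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < n) y[i] = 2.0f * x[i];
}
'''
raw_kernel = cp.RawModule(code=source).get_function('double_it')

# 2) A cupyx.jit kernel: the @jit.rawkernel() wrapper carries a `_mode`
#    attribute, matching the second branch (`hasattr(gpu_kernel, '_mode')`).
@jit.rawkernel()
def double_it_jit(x, y, n):
  i = jit.blockDim.x * jit.blockIdx.x + jit.threadIdx.x
  if i < n:
    y[i] = 2.0 * x[i]

# Either object can then be passed as XLACustomOp(gpu_kernel=raw_kernel) or
# XLACustomOp(gpu_kernel=double_it_jit); see the CuPy tutorial linked in the
# docstring for end-to-end usage, including grid/block launch parameters.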
