Implement customized op with cupy JIT Kernel
Routhleck committed Mar 19, 2024
1 parent f9cba21 commit c8be3ee
Showing 4 changed files with 229 additions and 34 deletions.
21 changes: 21 additions & 0 deletions brainpy/_src/dependency_check.py
@@ -9,6 +9,7 @@
  'import_numba',
  'raise_numba_not_found',
  'import_cupy',
  'import_cupy_jit',
  'raise_cupy_not_found',
  'import_brainpylib_cpu_ops',
  'import_brainpylib_gpu_ops',
@@ -20,6 +21,7 @@
numba = None
taichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

@@ -106,6 +108,25 @@ def import_cupy(error_if_not_found=True):
  return cupy


def import_cupy_jit(error_if_not_found=True):
  """
  Internal API to import cupyx.jit.

  If cupy is not found, it raises a ModuleNotFoundError when ``error_if_not_found`` is True;
  otherwise it returns None.
  """
  global cupy_jit
  if cupy_jit is None:
    try:
      from cupyx import jit as cupy_jit
    except ModuleNotFoundError:
      if error_if_not_found:
        raise_cupy_not_found()
      else:
        return None
  return cupy_jit


def raise_cupy_not_found():
  raise ModuleNotFoundError(cupy_install_info)

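For orientation, a minimal sketch of how this guarded import is consumed at module import time (mirroring the call sites added in base.py and cupy_based.py below; the fallback message is illustrative only):

from brainpy._src.dependency_check import import_cupy_jit

# Returns the cupyx.jit module, or None when cupy is not installed.
cp_jit = import_cupy_jit(error_if_not_found=False)

if cp_jit is None:
  # Without cupy, the cupy-based GPU lowerings simply stay unregistered;
  # the CPU paths remain fully functional.
  print('cupyx.jit unavailable; cupy GPU kernels are disabled')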
18 changes: 12 additions & 6 deletions brainpy/_src/math/op_register/base.py
@@ -5,24 +5,27 @@
import numpy as np
from jax.interpreters import xla, batching, ad, mlir

from brainpy._src.dependency_check import import_numba
from brainpy._src.dependency_check import import_numba, import_cupy_jit
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject

if jax.__version__ >= '0.4.16':
  from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import register_cupy_mlir_gpu_translation_rule as register_cupy_gpu_translation_rule
  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
  from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
  from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                 register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
  from .cupy_based import register_cupy_xla_gpu_translation_rule as register_cupy_gpu_translation_rule
  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
                           register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

numba = import_numba(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)

__all__ = [
  'XLACustomOp',
@@ -127,14 +130,17 @@ def __init__(
    gpu_checked = False
    if gpu_kernel is None:
      gpu_checked = True
    elif isinstance(gpu_kernel, str):  # cupy
      register_cupy_gpu_translation_rule(self.primitive, gpu_kernel)
    elif isinstance(gpu_kernel, str):  # cupy RawModule
      register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_mode'):  # cupy JIT kernel
      register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
      register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
      gpu_checked = True
    if not gpu_checked:
      raise ValueError(f'"gpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')
      raise ValueError(f'"gpu_kernel" must be a taichi kernel function, a cupy RawModule source string, '
                       f'or a cupy JIT kernel. But we got {gpu_kernel}')

    # batching rule
    if batching_translation is None:
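To make the new dispatch concrete, here is a minimal sketch of the accepted gpu_kernel forms (the CUDA source and kernel bodies are hypothetical; only the type detection mirrors the code above, and a GPU platform with cupy installed is assumed):

import brainpy.math as bm
from brainpy._src.dependency_check import import_cupy_jit

cp_jit = import_cupy_jit(error_if_not_found=False)

# 1) A CUDA source string is routed to the cupy RawModule path.
source = r'''
extern "C" __global__
void kernel(const float* x, float* y, unsigned int n) {
  unsigned int i = blockDim.x * blockIdx.x + threadIdx.x;
  if (i < n) y[i] = 2.0f * x[i];
}
'''
op_raw = bm.XLACustomOp(gpu_kernel=source)

# 2) A cupyx.jit.rawkernel is recognized via its `_mode` attribute and
#    routed to the cupy JIT kernel path.
@cp_jit.rawkernel()
def double_it(x, n, y):
  i = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
  if i < n:
    y[i] = 2.0 * x[i]

op_jit = bm.XLACustomOp(gpu_kernel=double_it)

# 3) A taichi kernel (recognized via `_is_wrapped_kernel`) is handled as before.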
179 changes: 173 additions & 6 deletions brainpy/_src/math/op_register/cupy_based.py
@@ -8,11 +8,13 @@
from jaxlib.hlo_helpers import custom_call

from brainpy._src.dependency_check import (import_cupy,
                                           import_cupy_jit,
                                           import_brainpylib_gpu_ops)
from brainpy._src.math.op_register.utils import _shape_to_layout
from brainpy.errors import PackageMissingError

cp = import_cupy(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)

# convert type to number
type_number_map = {
@@ -71,7 +73,7 @@ def _preprocess_kernel_call_gpu(
  return opaque


def _cupy_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
@@ -103,11 +105,11 @@ def _cupy_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
  )


def register_cupy_xla_gpu_translation_rule(primitive, gpu_kernel):
  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_xla_gpu_translation_rule, gpu_kernel)
def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel):
  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel)


def _cupy_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
@@ -140,9 +142,174 @@ def _cupy_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
  ).results


def register_cupy_mlir_gpu_translation_rule(primitive, gpu_kernel):
def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel):
  if cp is None:
    raise PackageMissingError("cupy", 'register cupy raw module mlir gpu translation rule')

  rule = partial(_cupy_mlir_gpu_translation_rule, gpu_kernel)
  rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel)
  mlir.register_lowering(primitive, rule, platform='gpu')


def get_jit_kernel_xla(kernel, c, *ins, outs):
  # Build the cupy JIT type signature from the XLA operand and output shapes.
  in_types = []
  for x in ins:
    x = c.get_shape(x)
    if len(x.dimensions()) != 0:
      t = cp_jit._cuda_types.CArray(dtype=x.element_type(), ndim=len(x.dimensions()),
                                    is_c_contiguous=True, index_32_bits=True)
    else:
      t = cp_jit._cuda_types.Scalar(dtype=x.element_type())
    in_types.append(t)
  for x in outs:
    if x.ndim != 0:
      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
    else:
      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
    in_types.append(t)
  in_types = tuple(in_types)

  # Reuse a kernel already compiled for this signature on this device, if any.
  device_id = cp.cuda.get_device_id()
  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))

  if kern is None:
    # Transpile the Python kernel to CUDA source, then compile and cache it.
    result = kernel._cached_codes.get(in_types)
    if result is None:
      result = cp_jit._compile.transpile(
        kernel._func,
        ['extern "C"', '__global__'],
        'cuda',
        in_types,
        cp_jit._cuda_types.void,
      )
      kernel._cached_codes[in_types] = result
    fname = result.func_name
    enable_cg = result.enable_cooperative_groups
    options = result.options
    backend = result.backend
    if backend == 'nvcc':
      options += ('-DCUPY_JIT_NVCC',)
    jitify = result.jitify
    module = cp._core.core.compile_with_cache(
      source=result.code,
      options=options,
      backend=backend,
      jitify=jitify,
    )
    kern = module.get_function(fname)
    kernel._cache[(in_types, device_id)] = (kern, enable_cg)

  return kern


def get_jit_kernel_mlir(kernel, c):
  # Build the cupy JIT type signature from the abstract input and output values.
  in_types = []
  for x in c.avals_in:
    if x.ndim != 0:
      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
    else:
      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
    in_types.append(t)
  for x in c.avals_out:
    if x.ndim != 0:
      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
    else:
      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
    in_types.append(t)
  in_types = tuple(in_types)

  # Reuse a kernel already compiled for this signature on this device, if any.
  device_id = cp.cuda.get_device_id()
  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))

  if kern is None:
    # Transpile the Python kernel to CUDA source, then compile and cache it.
    result = kernel._cached_codes.get(in_types)
    if result is None:
      result = cp_jit._compile.transpile(
        kernel._func,
        ['extern "C"', '__global__'],
        'cuda',
        in_types,
        cp_jit._cuda_types.void,
      )
      kernel._cached_codes[in_types] = result
    fname = result.func_name
    enable_cg = result.enable_cooperative_groups
    options = result.options
    backend = result.backend
    if backend == 'nvcc':
      options += ('-DCUPY_JIT_NVCC',)
    jitify = result.jitify
    module = cp._core.core.compile_with_cache(
      source=result.code,
      options=options,
      backend=backend,
      jitify=jitify,
    )
    kern = module.get_function(fname)
    kernel._cache[(in_types, device_id)] = (kern, enable_cg)

  return kern


def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
  kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs'])
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
  if grid is None or block is None:
    raise ValueError('The grid and block must be specified for the cupy kernel.')

  # preprocess
  import_brainpylib_gpu_ops()
  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])

  # create custom call
  return xla_client.ops.CustomCallWithLayout(
    c,
    b'cupy_kernel_call_gpu',
    operands=ins,
    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
    shape_with_layout=xla_client.Shape.tuple_shape(
      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
       for value in kwargs['outs']]
    ),
    opaque=opaque,
  )


def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel):
  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel)


def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
  kernel_func = get_jit_kernel_mlir(kernel, c)
  grid = kwargs.get('grid', None)
  block = kwargs.get('block', None)
  shared_mem = kwargs.get('shared_mem', 0)
  if grid is None or block is None:
    raise ValueError('The grid and block must be specified for the cupy kernel.')

  # preprocess
  import_brainpylib_gpu_ops()
  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])

  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
  output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]

  return custom_call(
    call_target_name='cupy_kernel_call_gpu',
    operands=ins,
    operand_layouts=list(input_layouts),
    result_layouts=list(output_layouts),
    result_types=list(result_types),
    backend_config=opaque,
    has_side_effect=False,
  ).results


def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel):
  if cp is None:
    raise PackageMissingError("cupy", 'register cupy jit kernel mlir gpu translation rule')

  rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel)
  mlir.register_lowering(primitive, rule, platform='gpu')
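For reference, the transpile/compile/cache sequence above mirrors what cupy itself performs when a cupyx.jit.rawkernel is launched directly; it is reproduced here so the raw device function pointer (kernel_func.ptr) can be handed to the 'cupy_kernel_call_gpu' custom call target. A standalone sketch of the direct launch, assuming cupy is installed (kernel name and sizes are illustrative):

import cupy as cp
from cupyx import jit

@jit.rawkernel()
def saxpy(a, x, y, out, size):
  # grid-stride loop, as in the elementwise_copy test kernel below
  tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
  ntid = jit.gridDim.x * jit.blockDim.x
  for i in range(tid, size, ntid):
    out[i] = a * x[i] + y[i]

n = 1024
x = cp.ones(n, dtype=cp.float32)
y = cp.ones(n, dtype=cp.float32)
out = cp.zeros(n, dtype=cp.float32)

# The first call triggers the same transpile -> compile -> cache steps
# that get_jit_kernel_xla / get_jit_kernel_mlir reproduce above.
saxpy((128,), (8,), (cp.float32(2.0), x, y, out, cp.uint32(n)))
assert cp.allclose(out, 3.0)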
45 changes: 23 additions & 22 deletions brainpy/_src/math/op_register/tests/test_cupy_based.py
@@ -3,16 +3,18 @@
import pytest

import brainpy.math as bm
from brainpy._src.dependency_check import import_cupy
from brainpy._src.dependency_check import import_cupy, import_cupy_jit

cp = import_cupy(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)
if cp is None:
  pytest.skip('no cupy', allow_module_level=True)

bm.set_platform('gpu')


def test_cupy_based():
  # Raw Module

  source_code = r'''
  extern "C"{
@@ -28,33 +30,32 @@ def test_cupy_based():
  '''
  N = 10
  x1 = bm.ones((N, N))
  # x1_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(x1)))
  x2 = bm.ones((N, N))
  # x2_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(x2)))
  y = bm.zeros((N, N))
  # y_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(y)))

  # mod = cp.RawModule(code=source_code)
  # kernel = mod.get_function('kernel')
  # y = kernel((N,), (N,), (x1_cp, x2_cp, N**2, y_cp))
  # print(y_cp)

  prim = bm.XLACustomOp(gpu_kernel=source_code)
  prim1 = bm.XLACustomOp(gpu_kernel=source_code)

  # n = jnp.asarray([N**2,], dtype=jnp.int32)

  y = prim(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]
  y = prim1(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]

  print(y)
  assert jnp.allclose(y, x1 + x2)

  # N = 10
  # x1 = cp.arange(N**2, dtype=cp.float32).reshape(N, N)
  # x2 = cp.ones((N, N), dtype=cp.float32)
  # y = cp.zeros((N, N), dtype=cp.float32)
  # ker_sum((N,), (N,), (x1, x2, y, N**2))  # y = x1 + x2
  # assert cp.allclose(y, x1 + x2)
  # ker_times((N,), (N,), (x1, x2, y, N**2))  # y = x1 * x2
  # assert cp.allclose(y, x1 * x2)
  # JIT Kernel
  @cp_jit.rawkernel()
  def elementwise_copy(x, size, y):
    tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
    ntid = cp_jit.gridDim.x * cp_jit.blockDim.x
    for i in range(tid, size, ntid):
      y[i] = x[i]

  size = 100
  x = bm.ones((size,))

  prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)

  y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=jnp.float32)])[0]

  print(y)
  assert jnp.allclose(y, x)

# test_cupy_based()
