diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py
index 997da4a8..1e106062 100644
--- a/brainpy/_src/dependency_check.py
+++ b/brainpy/_src/dependency_check.py
@@ -9,6 +9,7 @@
   'import_numba',
   'raise_numba_not_found',
   'import_cupy',
+  'import_cupy_jit',
   'raise_cupy_not_found',
   'import_brainpylib_cpu_ops',
   'import_brainpylib_gpu_ops',
@@ -20,6 +21,7 @@
 numba = None
 taichi = None
 cupy = None
+cupy_jit = None
 brainpylib_cpu_ops = None
 brainpylib_gpu_ops = None

@@ -106,6 +108,25 @@
   return cupy


+def import_cupy_jit(error_if_not_found=True):
+  """
+  Internal API to import the cupy JIT module (``cupyx.jit``).
+
+  If cupy is not found, a ModuleNotFoundError is raised when ``error_if_not_found``
+  is True; otherwise None is returned.
+  """
+  global cupy_jit
+  if cupy_jit is None:
+    try:
+      from cupyx import jit as cupy_jit
+    except ModuleNotFoundError:
+      if error_if_not_found:
+        raise_cupy_not_found()
+      else:
+        return None
+  return cupy_jit
+
+
 def raise_cupy_not_found():
   raise ModuleNotFoundError(cupy_install_info)

diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py
index ca0624a4..1aae5b8b 100644
--- a/brainpy/_src/math/op_register/base.py
+++ b/brainpy/_src/math/op_register/base.py
@@ -5,7 +5,7 @@
 import numpy as np
 from jax.interpreters import xla, batching, ad, mlir

-from brainpy._src.dependency_check import import_numba
+from brainpy._src.dependency_check import import_numba, import_cupy_jit
 from brainpy._src.math.ndarray import Array
 from brainpy._src.math.object_transform.base import BrainPyObject

@@ -13,16 +13,19 @@
   from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
   from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                  register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
-  from .cupy_based import register_cupy_mlir_gpu_translation_rule as register_cupy_gpu_translation_rule
+  from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
+                           register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
 else:
   from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
   from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
                                  register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
-  from .cupy_based import register_cupy_xla_gpu_translation_rule as register_cupy_gpu_translation_rule
+  from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
+                           register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
 from .utils import register_general_batching
 from brainpy._src.math.op_register.ad_support import defjvp

 numba = import_numba(error_if_not_found=False)
+cp_jit = import_cupy_jit(error_if_not_found=False)

 __all__ = [
   'XLACustomOp',
@@ -127,14 +130,17 @@ def __init__(
     gpu_checked = False
     if gpu_kernel is None:
       gpu_checked = True
-    elif isinstance(gpu_kernel, str):  # cupy
-      register_cupy_gpu_translation_rule(self.primitive, gpu_kernel)
+    elif isinstance(gpu_kernel, str):  # cupy RawModule
+      register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
+      gpu_checked = True
+    elif hasattr(gpu_kernel, '_mode'):  # cupy JIT kernel
+      register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
       gpu_checked = True
     elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel:  # taichi
       register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
       gpu_checked = True
     if not gpu_checked:
-      raise ValueError(f'"gpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')
+      raise ValueError(f'"gpu_kernel" must be a taichi kernel function, a cupy raw module, or a cupy JIT kernel. But we got {gpu_kernel}')

     # batching rule
     if batching_translation is None:
diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py
index d346a878..ad566b9b 100644
--- a/brainpy/_src/math/op_register/cupy_based.py
+++ b/brainpy/_src/math/op_register/cupy_based.py
@@ -8,11 +8,13 @@
 from jaxlib.hlo_helpers import custom_call

 from brainpy._src.dependency_check import (import_cupy,
+                                           import_cupy_jit,
                                            import_brainpylib_gpu_ops)
 from brainpy._src.math.op_register.utils import _shape_to_layout
 from brainpy.errors import PackageMissingError

 cp = import_cupy(error_if_not_found=False)
+cp_jit = import_cupy_jit(error_if_not_found=False)

 # convert type to number
 type_number_map = {
@@ -71,7 +73,7 @@
   return opaque


-def _cupy_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
+def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
   grid = kwargs.get('grid', None)
   block = kwargs.get('block', None)
   shared_mem = kwargs.get('shared_mem', 0)
@@ -103,11 +105,11 @@
   )


-def register_cupy_xla_gpu_translation_rule(primitive, gpu_kernel):
-  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_xla_gpu_translation_rule, gpu_kernel)
+def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel):
+  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel)


-def _cupy_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
+def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
   grid = kwargs.get('grid', None)
   block = kwargs.get('block', None)
   shared_mem = kwargs.get('shared_mem', 0)
@@ -140,9 +142,174 @@
   ).results


-def register_cupy_mlir_gpu_translation_rule(primitive, gpu_kernel):
+def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel):
   if cp is None:
     raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule')

-  rule = partial(_cupy_mlir_gpu_translation_rule, gpu_kernel)
+  rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel)
   mlir.register_lowering(primitive, rule, platform='gpu')
+
+
+def get_jit_kernel_xla(kernel, c, *ins, outs):
+  # check whether a kernel has already been compiled for these argument types
+  in_types = []
+  for x in ins:
+    x = c.get_shape(x)
+    if len(x.dimensions()) != 0:
+      t = cp_jit._cuda_types.CArray(dtype=x.element_type(), ndim=len(x.dimensions()), is_c_contiguous=True,
+                                    index_32_bits=True)
+    else:
+      t = cp_jit._cuda_types.Scalar(dtype=x.element_type())
+    in_types.append(t)
+  for x in outs:
+    if x.ndim != 0:
+      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
+    else:
+      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
+    in_types.append(t)
+  in_types = tuple(in_types)
+  device_id = cp.cuda.get_device_id()
+  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
+
+  if kern is None:
+    result = kernel._cached_codes.get(in_types)
+    if result is None:
+      result = cp_jit._compile.transpile(
+        kernel._func,
+        ['extern "C"', '__global__'],
+        'cuda',
+        in_types,
+        cp_jit._cuda_types.void,
+      )
+      kernel._cached_codes[in_types] = result
+    fname = result.func_name
+    enable_cg = result.enable_cooperative_groups
+    options = result.options
+    backend = result.backend
+    if backend == 'nvcc':
+      options += ('-DCUPY_JIT_NVCC',)
+    jitify = result.jitify
+    module = cp._core.core.compile_with_cache(
+      source=result.code,
+      options=options,
+      backend=backend,
+      jitify=jitify,
+    )
+    kern = module.get_function(fname)
+    kernel._cache[(in_types, device_id)] = (kern, enable_cg)
+
+  return kern
+
+
+def get_jit_kernel_mlir(kernel, c):
+  # check whether a kernel has already been compiled for these argument types
+  in_types = []
+  for x in c.avals_in:
+    if x.ndim != 0:
+      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
+    else:
+      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
+    in_types.append(t)
+  for x in c.avals_out:
+    if x.ndim != 0:
+      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
+    else:
+      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
+    in_types.append(t)
+  in_types = tuple(in_types)
+  device_id = cp.cuda.get_device_id()
+  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
+
+  if kern is None:
+    result = kernel._cached_codes.get(in_types)
+    if result is None:
+      result = cp_jit._compile.transpile(
+        kernel._func,
+        ['extern "C"', '__global__'],
+        'cuda',
+        in_types,
+        cp_jit._cuda_types.void,
+      )
+      kernel._cached_codes[in_types] = result
+    fname = result.func_name
+    enable_cg = result.enable_cooperative_groups
+    options = result.options
+    backend = result.backend
+    if backend == 'nvcc':
+      options += ('-DCUPY_JIT_NVCC',)
+    jitify = result.jitify
+    module = cp._core.core.compile_with_cache(
+      source=result.code,
+      options=options,
+      backend=backend,
+      jitify=jitify,
+    )
+    kern = module.get_function(fname)
+    kernel._cache[(in_types, device_id)] = (kern, enable_cg)
+
+  return kern
+
+
+def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
+  kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs'])
+  grid = kwargs.get('grid', None)
+  block = kwargs.get('block', None)
+  shared_mem = kwargs.get('shared_mem', 0)
+  if grid is None or block is None:
+    raise ValueError('The grid and block must be specified for the cupy kernel.')
+
+  # preprocess
+  import_brainpylib_gpu_ops()
+  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
+
+  # create custom call
+  return xla_client.ops.CustomCallWithLayout(
+    c,
+    b'cupy_kernel_call_gpu',
+    operands=ins,
+    operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins),
+    shape_with_layout=xla_client.Shape.tuple_shape(
+      [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape))
+       for value in kwargs['outs']]
+    ),
+    opaque=opaque,
+  )
+
+
+def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel):
+  xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel)
+
+
+def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
+  kernel_func = get_jit_kernel_mlir(kernel, c)
+  grid = kwargs.get('grid', None)
+  block = kwargs.get('block', None)
+  shared_mem = kwargs.get('shared_mem', 0)
+  if grid is None or block is None:
+    raise ValueError('The grid and block must be specified for the cupy kernel.')
+
+  # preprocess
+  import_brainpylib_gpu_ops()
+  opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs'])
+
+  input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
+  result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
+  output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out]
+
+  return custom_call(
+    call_target_name='cupy_kernel_call_gpu',
+    operands=ins,
+    operand_layouts=list(input_layouts),
+    result_layouts=list(output_layouts),
+    result_types=list(result_types),
+    backend_config=opaque,
+    has_side_effect=False,
+  ).results
+
+
+def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel):
+  if cp is None:
+    raise PackageMissingError("cupy", 'register cupy jit kernel mlir gpu translation rule')
+
+  rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel)
   mlir.register_lowering(primitive, rule, platform='gpu')
diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py
index b7656c43..4bcc1323 100644
--- a/brainpy/_src/math/op_register/tests/test_cupy_based.py
+++ b/brainpy/_src/math/op_register/tests/test_cupy_based.py
@@ -3,16 +3,18 @@
 import pytest

 import brainpy.math as bm
-from brainpy._src.dependency_check import import_cupy
+from brainpy._src.dependency_check import import_cupy, import_cupy_jit

 cp = import_cupy(error_if_not_found=False)
+cp_jit = import_cupy_jit(error_if_not_found=False)
 if cp is None:
   pytest.skip('no cupy', allow_module_level=True)
-
 bm.set_platform('gpu')


 def test_cupy_based():
+  # Raw Module
+
   source_code = r'''
   extern "C"{
@@ -28,33 +30,32 @@
   '''
   N = 10
   x1 = bm.ones((N, N))
-  # x1_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(x1)))
   x2 = bm.ones((N, N))
-  # x2_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(x2)))
-  y = bm.zeros((N, N))
-  # y_cp = cp.from_dlpack(jax.dlpack.to_dlpack(as_jax(y)))
-
-  # mod = cp.RawModule(code=source_code)
-  # kernel = mod.get_function('kernel')
-  # y = kernel((N,), (N,), (x1_cp, x2_cp, N**2, y_cp))
-  # print(y_cp)
-
-  prim = bm.XLACustomOp(gpu_kernel=source_code)
+  prim1 = bm.XLACustomOp(gpu_kernel=source_code)

   # n = jnp.asarray([N**2,], dtype=jnp.int32)
-  y = prim(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]
+  y = prim1(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]

   print(y)
   assert jnp.allclose(y, x1 + x2)

-  # N = 10
-  # x1 = cp.arange(N**2, dtype=cp.float32).reshape(N, N)
-  # x2 = cp.ones((N, N), dtype=cp.float32)
-  # y = cp.zeros((N, N), dtype=cp.float32)
-  # ker_sum((N,), (N,), (x1, x2, y, N**2))  # y = x1 + x2
-  # assert cp.allclose(y, x1 + x2)
-  # ker_times((N,), (N,), (x1, x2, y, N**2))  # y = x1 * x2
-  # assert cp.allclose(y, x1 * x2)
+  # JIT Kernel
+  @cp_jit.rawkernel()
+  def elementwise_copy(x, size, y):
+    tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
+    ntid = cp_jit.gridDim.x * cp_jit.blockDim.x
+    for i in range(tid, size, ntid):
+      y[i] = x[i]
+
+  size = 100
+  x = bm.ones((size,))
+
+  prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)
+
+  y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=jnp.float32)])[0]
+
+  print(y)
+  assert jnp.allclose(y, x)

 # test_cupy_based()
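
Reviewer note: taken together, these changes let `XLACustomOp` accept a `cupyx.jit.rawkernel` object directly (dispatched via its `_mode` attribute), alongside the existing cupy `RawModule` source-string path. Below is a minimal usage sketch distilled from the test above; it assumes cupy and a CUDA-capable GPU are available, and `elementwise_copy` is just the example kernel from the test, not part of the API.

```python
import jax
import jax.numpy as jnp

import brainpy.math as bm
from brainpy._src.dependency_check import import_cupy_jit

cp_jit = import_cupy_jit(error_if_not_found=False)  # returns None when cupy is missing

# A cupyx.jit kernel: a grid-stride loop that copies x into y.
@cp_jit.rawkernel()
def elementwise_copy(x, size, y):
  tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
  ntid = cp_jit.gridDim.x * cp_jit.blockDim.x
  for i in range(tid, size, ntid):
    y[i] = x[i]

# The kernel object carries a `_mode` attribute, so XLACustomOp routes it
# through the new cupy JIT kernel translation rule rather than the
# RawModule (source string) path.
prim = bm.XLACustomOp(gpu_kernel=elementwise_copy)

size = 100
x = bm.ones((size,))
y = prim(x, size, grid=(10,), block=(10,),
         outs=[jax.ShapeDtypeStruct((size,), dtype=jnp.float32)])[0]
assert jnp.allclose(y, x)
```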