diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index b8bd6e99..1e106062 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -8,6 +8,9 @@ 'raise_taichi_not_found', 'import_numba', 'raise_numba_not_found', + 'import_cupy', + 'import_cupy_jit', + 'raise_cupy_not_found', 'import_brainpylib_cpu_ops', 'import_brainpylib_gpu_ops', ] @@ -17,6 +20,8 @@ numba = None taichi = None +cupy = None +cupy_jit = None brainpylib_cpu_ops = None brainpylib_gpu_ops = None @@ -25,6 +30,9 @@ '> pip install taichi==1.7.0') numba_install_info = ('We need numba. Please install numba by pip . \n' '> pip install numba') +cupy_install_info = ('We need cupy. Please install cupy by pip . \n' + 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' + 'For CUDA v12.x > pip install cupy-cuda12x\n') os.environ["TI_LOG_LEVEL"] = "error" @@ -81,6 +89,48 @@ def raise_numba_not_found(): raise ModuleNotFoundError(numba_install_info) +def import_cupy(error_if_not_found=True): + """ + Internal API to import cupy. + + If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global cupy + if cupy is None: + try: + import cupy as cupy + except ModuleNotFoundError: + if error_if_not_found: + raise_cupy_not_found() + else: + return None + return cupy + + +def import_cupy_jit(error_if_not_found=True): + """ + Internal API to import cupy. + + If cupy is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global cupy_jit + if cupy_jit is None: + try: + from cupyx import jit as cupy_jit + except ModuleNotFoundError: + if error_if_not_found: + raise_cupy_not_found() + else: + return None + return cupy_jit + + +def raise_cupy_not_found(): + raise ModuleNotFoundError(cupy_install_info) + + def is_brainpylib_gpu_installed(): return False if brainpylib_gpu_ops is None else True diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index ca070a19..5af5a7e3 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -1,11 +1,11 @@ from functools import partial -from typing import Callable, Sequence, Tuple, Protocol, Optional +from typing import Callable, Sequence, Tuple, Protocol, Optional, Union import jax import numpy as np from jax.interpreters import xla, batching, ad, mlir -from brainpy._src.dependency_check import import_numba +from brainpy._src.dependency_check import import_numba, import_cupy_jit from brainpy._src.math.ndarray import Array from brainpy._src.math.object_transform.base import BrainPyObject @@ -13,14 +13,19 @@ from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule) + from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) else: from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule, register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule) + from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule, + register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule) from .utils import register_general_batching from brainpy._src.math.op_register.ad_support import defjvp numba = import_numba(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) __all__ = [ 'XLACustomOp', @@ -41,34 +46,10 @@ def dtype(self) -> np.dtype: class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. - >>> import numba as nb - >>> import taichi as ti - >>> import numpy as np - >>> import jax - >>> - >>> @nb.njit - >>> def numba_cpu_fun(a, b, out_a, out_b): - >>> out_a[:] = a - >>> out_b[:] = b - >>> - >>> @ti.kernel - >>> def taichi_gpu_fun(a, b, out_a, out_b): - >>> for i in range(a.size): - >>> out_a[i] = a[i] - >>> for i in range(b.size): - >>> out_b[i] = b[i] - >>> - >>> # option 1 - >>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun) - >>> a2, b2 = prim(np.random.random(1000), np.random.random(1000), - >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), - >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) - >>> - >>> # option 2 - >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun, - >>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype), - >>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)]) - >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000)) + For more information, please refer to the tutorials above: + Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html + Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html + CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html Args: cpu_kernel: Callable. The function defines the computation on CPU backend. @@ -83,7 +64,7 @@ class XLACustomOp(BrainPyObject): def __init__( self, cpu_kernel: Callable = None, - gpu_kernel: Callable = None, + gpu_kernel: Union[Callable, str] = None, batching_translation: Callable = None, jvp_translation: Callable = None, transpose_translation: Callable = None, @@ -125,11 +106,17 @@ def __init__( gpu_checked = False if gpu_kernel is None: gpu_checked = True - if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi + elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule + register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel) + gpu_checked = True + elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel + register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel) + gpu_checked = True + elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi register_taichi_gpu_translation_rule(self.primitive, gpu_kernel) gpu_checked = True if not gpu_checked: - raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}') + raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}') # batching rule if batching_translation is None: diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py new file mode 100644 index 00000000..ad6befec --- /dev/null +++ b/brainpy/_src/math/op_register/cupy_based.py @@ -0,0 +1,279 @@ +from functools import partial, reduce +from typing import List, Tuple + +import jax +import numpy as np +from jax.interpreters import xla, mlir +from jax.lib import xla_client +from jaxlib.hlo_helpers import custom_call + +from brainpy._src.dependency_check import (import_cupy, + import_cupy_jit, + import_brainpylib_gpu_ops) +from brainpy._src.math.op_register.utils import _shape_to_layout +from brainpy.errors import PackageMissingError + +cp = import_cupy(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) + +# convert type to number +type_number_map = { + int: 0, + float: 1, + bool: 2, + np.dtype('int32'): 0, + np.dtype('float32'): 1, + np.dtype('bool'): 2, + np.dtype('uint8'): 3, + np.dtype('uint16'): 4, + np.dtype('uint32'): 5, + np.dtype('uint64'): 6, + np.dtype('int8'): 7, + np.dtype('int16'): 8, + np.dtype('int64'): 9, + np.dtype('float16'): 10, + np.dtype('float64'): 11, +} + + +def _preprocess_kernel_call_gpu( + grid: Tuple[int], + block: Tuple[int], + func_ptr: int, + shared_mem: int, + *ins, + outs: List[jax.ShapeDtypeStruct], +): + grid = (grid + (1, 1))[:3] + block = (block + (1, 1))[:3] + in_num = len(ins) + out_num = len(outs) + in_out_num = [in_num, out_num] + + out_type_list = [0] * out_num + out_elem_count_list = [0] * out_num + + for i, value in enumerate(outs): + out_type_list[i] = type_number_map[value.dtype] + out_elem_count_list[i] = reduce(lambda x, y: x * y, value.shape) + + grid = ",".join(str(i) for i in grid) + block = ",".join(str(i) for i in block) + in_out_num_str = ",".join(str(i) for i in in_out_num) + out_type_list_str = ",".join(str(i) for i in out_type_list) + out_elem_count_list_str = ",".join(str(i) for i in out_elem_count_list) + + opaque = (bytes(str(func_ptr), encoding='utf-8') + b';' + + bytes(str(shared_mem), encoding='utf-8') + b';' + + bytes(in_out_num_str, encoding='utf-8') + b';' + + bytes(grid, encoding='utf-8') + b';' + + bytes(block, encoding='utf-8') + b';' + + bytes(out_type_list_str, encoding='utf-8') + b';' + + bytes(out_elem_count_list_str, encoding='utf-8') + b';') + return opaque + + +def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + # THE KEY: + # - using the kernel pointer at "kernel.kernel.ptr" + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + + # create custom call + return xla_client.ops.CustomCallWithLayout( + c, + b'cupy_kernel_call_gpu', + operands=ins, + operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), + shape_with_layout=xla_client.Shape.tuple_shape( + [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) + for value in kwargs['outs']] + ), + opaque=opaque, + ) + + +def register_cupy_raw_module_xla_gpu_translation_rule(primitive, gpu_kernel): + xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_raw_module_xla_gpu_translation_rule, gpu_kernel) + + +def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + + input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] + result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] + output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] + + return custom_call( + call_target_name='cupy_kernel_call_gpu', + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + backend_config=opaque, + has_side_effect=False, + ).results + + +def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel): + if cp is None: + raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') + + rule = partial(_cupy_raw_module_mlir_gpu_translation_rule, gpu_kernel) + mlir.register_lowering(primitive, rule, platform='gpu') + + +def _to_cupy_array_or_scalar(dtype, ndim): + # THE KEY + # - using the cupy jit compiler to get the type + if ndim != 0: + t = cp_jit._cuda_types.CArray(dtype=dtype, + ndim=ndim, + is_c_contiguous=True, + index_32_bits=True) + else: + t = cp_jit._cuda_types.Scalar(dtype=dtype) + return t + + +def _compile_kernel_xla(kernel, in_types): + # THE KEY + # - get the kernel function from the cache + device_id = cp.cuda.get_device_id() + kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None)) + + if kern is None: + # THE KEY: + # - compile the kernel function + result = kernel._cached_codes.get(in_types) + if result is None: + result = cp_jit._compile.transpile( + kernel._func, + ['extern "C"', '__global__'], + 'cuda', + in_types, + cp_jit._cuda_types.void, + ) + kernel._cached_codes[in_types] = result + fname = result.func_name + enable_cg = result.enable_cooperative_groups + options = result.options + backend = result.backend + if backend == 'nvcc': + options += ('-DCUPY_JIT_NVCC',) + jitify = result.jitify + module = cp._core.core.compile_with_cache( + source=result.code, + options=options, + backend=backend, + jitify=jitify, + ) + kern = module.get_function(fname) + kernel._cache[(in_types, device_id)] = (kern, enable_cg) + + return kern + + +def get_jit_kernel_xla(kernel, c, *ins, outs): + # get the input types + in_types = [] + for x in ins: + x = c.get_shape(x) + in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions()))) + for x in outs: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + in_types = tuple(in_types) + # compile the kernel + return _compile_kernel_xla(kernel, in_types) + + +def get_jit_kernel_mlir(kernel, c): + # get the input types + in_types = [] + for x in c.avals_in: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + for x in c.avals_out: + in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim)) + in_types = tuple(in_types) + # compile the kernel + return _compile_kernel_xla(kernel, in_types) + + +def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): + kernel_func = get_jit_kernel_xla(kernel, c, *ins, outs=kwargs['outs']) + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) + + # create custom call + return xla_client.ops.CustomCallWithLayout( + c, + b'cupy_kernel_call_gpu', + operands=ins, + operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), + shape_with_layout=xla_client.Shape.tuple_shape( + [xla_client.Shape.array_shape(value.dtype, value.shape, _shape_to_layout(value.shape)) + for value in kwargs['outs']] + ), + opaque=opaque, + ) + + +def register_cupy_jit_kernel_xla_gpu_translation_rule(primitive, gpu_kernel): + xla.backend_specific_translations['gpu'][primitive] = partial(_cupy_jit_kernel_xla_gpu_translation_rule, gpu_kernel) + + +def _cupy_jit_kernel_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): + kernel_func = get_jit_kernel_mlir(kernel, c) + grid = kwargs.get('grid', None) + block = kwargs.get('block', None) + shared_mem = kwargs.get('shared_mem', 0) + if grid is None or block is None: + raise ValueError('The grid and block should be specified for the cupy kernel.') + + # preprocess + import_brainpylib_gpu_ops() + opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.ptr, shared_mem, *ins, outs=kwargs['outs']) + + input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] + result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] + output_layouts = [_shape_to_layout(a.shape) for a in c.avals_out] + + return custom_call( + call_target_name='cupy_kernel_call_gpu', + operands=ins, + operand_layouts=list(input_layouts), + result_layouts=list(output_layouts), + result_types=list(result_types), + backend_config=opaque, + has_side_effect=False, + ).results + + +def register_cupy_jit_kernel_mlir_gpu_translation_rule(primitive, gpu_kernel): + if cp is None: + raise PackageMissingError("cupy", 'register cupy mlir gpu translation rule') + + rule = partial(_cupy_jit_kernel_mlir_gpu_translation_rule, gpu_kernel) + mlir.register_lowering(primitive, rule, platform='gpu') diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py new file mode 100644 index 00000000..772b6160 --- /dev/null +++ b/brainpy/_src/math/op_register/tests/test_cupy_based.py @@ -0,0 +1,79 @@ +import jax +import pytest + +import brainpy.math as bm +from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi + +cp = import_cupy(error_if_not_found=False) +cp_jit = import_cupy_jit(error_if_not_found=False) +ti = import_taichi(error_if_not_found=False) +if cp is None or ti is None: + pytest.skip('no cupy or taichi', allow_module_level=True) +bm.set_platform('cpu') + + +def test_cupy_based(): + bm.op_register.clear_taichi_aot_caches() + # Raw Module + + @ti.kernel + def simpleAdd(x1: ti.types.ndarray(ndim=2), + x2: ti.types.ndarray(ndim=2), + n: ti.types.ndarray(ndim=0), + y: ti.types.ndarray(ndim=2)): + for i, j in y: + y[i, j] = x1[i, j] + x2[i, j] + + source_code = r''' + extern "C"{ + + __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y) + { + unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < N) + { + y[tid] = x1[tid] + x2[tid]; + } + } + } + ''' + N = 10 + x1 = bm.ones((N, N)) + x2 = bm.ones((N, N)) + + mod = cp.RawModule(code=source_code) + kernel = mod.get_function('kernel') + + prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel) + + y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0] + + print(y) + assert bm.allclose(y, x1 + x2) + + # JIT Kernel + @ti.kernel + def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1), + size: ti.types.ndarray(ndim=1), + y: ti.types.ndarray(ndim=1)): + for i in y: + y[i] = x[i] + + @cp_jit.rawkernel() + def elementwise_copy(x, size, y): + tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x + ntid = cp_jit.gridDim.x * cp_jit.blockDim.x + for i in range(tid, size, ntid): + y[i] = x[i] + + size = 100 + x = bm.ones((size,)) + + prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy) + + y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0] + + print(y) + assert bm.allclose(y, x) + +# test_cupy_based() diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 199dce98..ea6dcadc 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -11,10 +11,9 @@ bm.set_platform('cpu') - @ti.func -def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: - return weight[0] +def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32: + return weight[None] @ti.func @@ -25,7 +24,7 @@ def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -35,11 +34,10 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - @ti.kernel def event_ell_gpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -48,21 +46,18 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 100)) + indices = bm.random.randint(0, s, (s, 1000)) vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) print(out) - bm.clear_buffer_memory() # test_taichi_op_register() diff --git a/docs/tutorial_advanced/operator_custom_with_cupy.ipynb b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb new file mode 100644 index 00000000..0b4bf241 --- /dev/null +++ b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CPU and GPU Operator Customization with CuPy\n", + "\n", + "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This functionality is only available for ``brainpylib>=0.3.1``. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## English Version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Although we can now use the flexible taichi custom operator approach, taichi on cuda does not have more fine-grained control or optimization for some scenarios. So for such scenarios, we can use cupy's \n", + "- `RawModule`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", + "- `jit.rawkernel`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n", + "\n", + "to compile and run CUDA native code directly as strings or cupy JIT function in real time for finer grained control.\n", + "\n", + "Start by importing the relevant Python package." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "import cupy as cp\n", + "from cupyx import jit\n", + "\n", + "bm.set_platform('gpu')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy RawModule\n", + "\n", + "For dealing a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be more handy. It can be initialized either by a CUDA source code. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.\n", + "\n", + "Be aware that the order of parameters in the kernel function you want to call should **keep outputs at the end of the parameter list**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_code = r'''\n", + " extern \"C\"{\n", + "\n", + " __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n", + " {\n", + " unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n", + " if (tid < N)\n", + " {\n", + " y[tid] = x1[tid] + x2[tid];\n", + " }\n", + " }\n", + " }\n", + "'''\n", + "mod = cp.RawModule(code=source_code)\n", + "kernel = mod.get_function('kernel')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "N = 10\n", + "x1 = bm.ones((N, N))\n", + "x2 = bm.ones((N, N))\n", + "\n", + "# register the kernel as a custom op\n", + "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n", + "\n", + "# call the custom op\n", + "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy JIT RawKernel\n", + "The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.\n", + "\n", + "In this section, a Python function wrapped with the decorator is called a target function.\n", + "\n", + "Here is a short example for how to write a cupyx.jit.rawkernel to copy the values from x to y using a grid-stride loop:\n", + "\n", + "Launching a CUDA kernel on a GPU with pre-determined grid/block sizes requires basic understanding in the CUDA Programming Model. And the compilation will be deferred until the first function call. CuPy’s JIT compiler infers the types of arguments at the call time, and will cache the compiled kernels for speeding up any subsequent calls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jit.rawkernel()\n", + "def elementwise_copy(x, size, y):\n", + " tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n", + " ntid = jit.gridDim.x * jit.blockDim.x\n", + " for i in range(tid, size, ntid):\n", + " y[i] = x[i]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After define the `jit.rawkernel`. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "size = 100\n", + "x = bm.ones((size,))\n", + "\n", + "# register the kernel as a custom op\n", + "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n", + "\n", + "# call the custom op\n", + "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 3c2667df..4b86a426 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -99,8 +99,8 @@ "\n", "```python\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "@ti.func\n", "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n", @@ -109,7 +109,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -207,8 +207,8 @@ "bm.set_platform('cpu')\n", "\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "\n", "@ti.func\n", @@ -219,7 +219,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -232,7 +232,7 @@ "@ti.kernel\n", "def event_ell_gpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1), \n", - " weight: ti.types.ndarray(ndim=1), \n", + " weight: ti.types.ndarray(ndim=0), \n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -248,11 +248,10 @@ " s = 1000\n", " indices = bm.random.randint(0, s, (s, 1000))\n", " vector = bm.random.rand(s) < 0.1\n", - " weight = bm.array([1.0])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", " print(out)\n", "\n",