From 0cca5df935192c413cd6a5931eaf33d996f393fa Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 19 Mar 2024 15:57:28 +0800 Subject: [PATCH] Update docs --- brainpy/_src/math/op_register/base.py | 34 +--- brainpy/_src/math/op_register/cupy_based.py | 18 +- .../math/op_register/tests/test_cupy_based.py | 42 +++-- .../op_register/tests/test_taichi_based.py | 19 +- .../operator_custom_with_cupy.ipynb | 174 ++++++++++++++++++ .../operator_custom_with_taichi.ipynb | 19 +- 6 files changed, 227 insertions(+), 79 deletions(-) create mode 100644 docs/tutorial_advanced/operator_custom_with_cupy.ipynb diff --git a/brainpy/_src/math/op_register/base.py b/brainpy/_src/math/op_register/base.py index 1aae5b8b..5af5a7e3 100644 --- a/brainpy/_src/math/op_register/base.py +++ b/brainpy/_src/math/op_register/base.py @@ -46,34 +46,10 @@ def dtype(self) -> np.dtype: class XLACustomOp(BrainPyObject): """Creating a XLA custom call operator. - >>> import numba as nb - >>> import taichi as ti - >>> import numpy as np - >>> import jax - >>> - >>> @nb.njit - >>> def numba_cpu_fun(a, b, out_a, out_b): - >>> out_a[:] = a - >>> out_b[:] = b - >>> - >>> @ti.kernel - >>> def taichi_gpu_fun(a, b, out_a, out_b): - >>> for i in range(a.size): - >>> out_a[i] = a[i] - >>> for i in range(b.size): - >>> out_b[i] = b[i] - >>> - >>> # option 1 - >>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun) - >>> a2, b2 = prim(np.random.random(1000), np.random.random(1000), - >>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32), - >>> jax.ShapeDtypeStruct(1000, dtype=np.float32)]) - >>> - >>> # option 2 - >>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun, - >>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype), - >>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)]) - >>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000)) + For more information, please refer to the following tutorials: + Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html + Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html + CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html Args: cpu_kernel: Callable. The function defines the computation on CPU backend.
@@ -130,7 +106,7 @@ def __init__( gpu_checked = False if gpu_kernel is None: gpu_checked = True - elif isinstance(gpu_kernel, str): # cupy RawModule + elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel) gpu_checked = True elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel diff --git a/brainpy/_src/math/op_register/cupy_based.py b/brainpy/_src/math/op_register/cupy_based.py index ad566b9b..86545a98 100644 --- a/brainpy/_src/math/op_register/cupy_based.py +++ b/brainpy/_src/math/op_register/cupy_based.py @@ -80,16 +80,9 @@ def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs): if grid is None or block is None: raise ValueError('The grid and block should be specified for the cupy kernel.') - # compile - mod = cp.RawModule(code=kernel) - try: - kernel_func = mod.get_function('kernel') - except AttributeError: - raise ValueError('The \'kernel\' function is not found in the module.') - # preprocess import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) # create custom call return xla_client.ops.CustomCallWithLayout( @@ -116,16 +109,9 @@ def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs): if grid is None or block is None: raise ValueError('The grid and block should be specified for the cupy kernel.') - # compile - mod = cp.RawModule(code=kernel) - try: - kernel_func = mod.get_function('kernel') - except AttributeError: - raise ValueError('The \'kernel\' function is not found in the module.') - # preprocess import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) + opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs']) input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in] result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out] diff --git a/brainpy/_src/math/op_register/tests/test_cupy_based.py b/brainpy/_src/math/op_register/tests/test_cupy_based.py index 4bcc1323..772b6160 100644 --- a/brainpy/_src/math/op_register/tests/test_cupy_based.py +++ b/brainpy/_src/math/op_register/tests/test_cupy_based.py @@ -1,20 +1,29 @@ import jax -import jax.numpy as jnp import pytest import brainpy.math as bm -from brainpy._src.dependency_check import import_cupy, import_cupy_jit +from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi cp = import_cupy(error_if_not_found=False) cp_jit = import_cupy_jit(error_if_not_found=False) -if cp is None: - pytest.skip('no cupy', allow_module_level=True) -bm.set_platform('gpu') +ti = import_taichi(error_if_not_found=False) +if cp is None or ti is None: + pytest.skip('no cupy or taichi', allow_module_level=True) +bm.set_platform('cpu') def test_cupy_based(): + bm.op_register.clear_taichi_aot_caches() # Raw Module + @ti.kernel + def simpleAdd(x1: ti.types.ndarray(ndim=2), + x2: ti.types.ndarray(ndim=2), + n: ti.types.ndarray(ndim=0), + y: ti.types.ndarray(ndim=2)): + for i, j in y: + y[i, j] = x1[i, j] + x2[i, j] + source_code = r''' extern "C"{ @@ -31,16 +40,25 @@ def test_cupy_based(): N = 10 x1 = bm.ones((N, N)) x2 = bm.ones((N, N)) - prim1 = bm.XLACustomOp(gpu_kernel=source_code) - # n = jnp.asarray([N**2,], dtype=jnp.int32) + mod = cp.RawModule(code=source_code) + kernel = 
mod.get_function('kernel') + + prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel) - y = prim1(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0] + y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0] print(y) - assert jnp.allclose(y, x1 + x2) + assert bm.allclose(y, x1 + x2) # JIT Kernel + @ti.kernel + def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1), + size: ti.types.ndarray(ndim=1), + y: ti.types.ndarray(ndim=1)): + for i in y: + y[i] = x[i] + @cp_jit.rawkernel() def elementwise_copy(x, size, y): tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x @@ -51,11 +69,11 @@ def elementwise_copy(x, size, y): size = 100 x = bm.ones((size,)) - prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy) + prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy) - y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=jnp.float32)])[0] + y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0] print(y) - assert jnp.allclose(y, x) + assert bm.allclose(y, x) # test_cupy_based() diff --git a/brainpy/_src/math/op_register/tests/test_taichi_based.py b/brainpy/_src/math/op_register/tests/test_taichi_based.py index 199dce98..ea6dcadc 100644 --- a/brainpy/_src/math/op_register/tests/test_taichi_based.py +++ b/brainpy/_src/math/op_register/tests/test_taichi_based.py @@ -11,10 +11,9 @@ bm.set_platform('cpu') - @ti.func -def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32: - return weight[0] +def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32: + return weight[None] @ti.func @@ -25,7 +24,7 @@ def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -35,11 +34,10 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - @ti.kernel def event_ell_gpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), - weight: ti.types.ndarray(ndim=1), + weight: ti.types.ndarray(ndim=0), out: ti.types.ndarray(ndim=1)): weight_val = get_weight(weight) num_rows, num_cols = indices.shape @@ -48,21 +46,18 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2), for j in range(num_cols): update_output(out, indices[i, j], weight_val) - prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu) def test_taichi_op_register(): s = 1000 - indices = bm.random.randint(0, s, (s, 100)) + indices = bm.random.randint(0, s, (s, 1000)) vector = bm.random.rand(s) < 0.1 - weight = bm.array([1.0]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) - out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) + out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)]) print(out) - bm.clear_buffer_memory() # test_taichi_op_register() diff --git a/docs/tutorial_advanced/operator_custom_with_cupy.ipynb b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb new file mode 100644 index 00000000..0b4bf241 --- /dev/null +++ 
b/docs/tutorial_advanced/operator_custom_with_cupy.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CPU and GPU Operator Customization with CuPy\n", + "\n", + "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)\n", + "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This functionality is only available for ``brainpylib>=0.3.1``." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## English Version" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Although the Taichi custom operator approach is flexible, Taichi on CUDA does not offer the fine-grained control or optimization that some scenarios require. For such scenarios, we can use CuPy's\n", + "- `RawModule` (https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n", + "- `jit.rawkernel` (https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition)\n", + "\n", + "to compile and run native CUDA code on the fly, either as raw source strings or as CuPy JIT functions, for finer-grained control.\n", + "\n", + "Start by importing the relevant Python packages." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import brainpy.math as bm\n", + "\n", + "import jax\n", + "import cupy as cp\n", + "from cupyx import jit\n", + "\n", + "bm.set_platform('gpu')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy RawModule\n", + "\n", + "For dealing with a large raw CUDA source or loading an existing CUDA binary, the `RawModule` class can be more handy. It can be initialized from a CUDA source string. The needed kernels can then be retrieved by calling the `get_function()` method, which returns a `RawKernel` instance that can be invoked directly.\n", + "\n", + "Be aware that the kernel function you want to call must **keep its outputs at the end of the parameter list**." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source_code = r'''\n", + " extern \"C\"{\n", + "\n", + " __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n", + " {\n", + " unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n", + " if (tid < N)\n", + " {\n", + " y[tid] = x1[tid] + x2[tid];\n", + " }\n", + " }\n", + " }\n", + "'''\n", + "mod = cp.RawModule(code=source_code)\n", + "kernel = mod.get_function('kernel')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After defining the `RawModule` and getting the kernel function, you can register it as the `gpu_kernel` of a `bm.XLACustomOp` and call it with the appropriate `grid` and `block` (**both of these parameters should be tuples**)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "N = 10\n", + "x1 = bm.ones((N, N))\n", + "x2 = bm.ones((N, N))\n", + "\n", + "# register the kernel as a custom op\n", + "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n", + "\n", + "# call the custom op\n", + "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### CuPy JIT RawKernel\n", + "The `cupyx.jit.rawkernel` decorator can create raw CUDA kernels from Python functions.\n", + "\n", + "In this section, a Python function wrapped with the decorator is called a target function.\n", + "\n", + "Below is a short example of how to write a `cupyx.jit.rawkernel` that copies the values from `x` to `y` using a grid-stride loop.\n", + "\n", + "Launching a CUDA kernel on a GPU with pre-determined grid/block sizes requires a basic understanding of the CUDA programming model. Compilation is deferred until the first function call: CuPy's JIT compiler infers the argument types at call time and caches the compiled kernels to speed up subsequent calls." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "@jit.rawkernel()\n", + "def elementwise_copy(x, size, y):\n", + " tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n", + " ntid = jit.gridDim.x * jit.blockDim.x\n", + " for i in range(tid, size, ntid):\n", + " y[i] = x[i]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After defining the `jit.rawkernel`, you can register it as the `gpu_kernel` of a `bm.XLACustomOp` and call it with the appropriate `grid` and `block` (**both of these parameters should be tuples**)."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prepare inputs\n", + "size = 100\n", + "x = bm.ones((size,))\n", + "\n", + "# register the kernel as a custom op\n", + "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n", + "\n", + "# call the custom op\n", + "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb index 3c2667df..4b86a426 100644 --- a/docs/tutorial_advanced/operator_custom_with_taichi.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_taichi.ipynb @@ -99,8 +99,8 @@ "\n", "```python\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "@ti.func\n", "def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):\n", @@ -109,7 +109,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -207,8 +207,8 @@ "bm.set_platform('cpu')\n", "\n", "@ti.func\n", - "def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:\n", - " return weight[0]\n", + "def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:\n", + " return weight[None]\n", "\n", "\n", "@ti.func\n", @@ -219,7 +219,7 @@ "@ti.kernel\n", "def event_ell_cpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1),\n", - " weight: ti.types.ndarray(ndim=1),\n", + " weight: ti.types.ndarray(ndim=0),\n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -232,7 +232,7 @@ "@ti.kernel\n", "def event_ell_gpu(indices: ti.types.ndarray(ndim=2),\n", " vector: ti.types.ndarray(ndim=1), \n", - " weight: ti.types.ndarray(ndim=1), \n", + " weight: ti.types.ndarray(ndim=0), \n", " out: ti.types.ndarray(ndim=1)):\n", " weight_val = get_weight(weight)\n", " num_rows, num_cols = indices.shape\n", @@ -248,11 +248,10 @@ " s = 1000\n", " indices = bm.random.randint(0, s, (s, 1000))\n", " vector = bm.random.rand(s) < 0.1\n", - " weight = bm.array([1.0])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", - " out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", + " out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])\n", "\n", " print(out)\n", "\n",
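Note: for quick orientation, below is a minimal end-to-end sketch of the CuPy `RawModule` workflow this patch documents — compile the module, fetch the kernel object, and pass that object (rather than the source string) as the `gpu_kernel` of `bm.XLACustomOp`. It simply mirrors the example added in the new notebook and tests; it assumes `brainpylib>=0.3.1` and a CUDA-capable GPU, and the shapes and kernel body are illustrative only.

```python
import jax
import cupy as cp
import brainpy.math as bm

bm.set_platform('gpu')

source_code = r'''
extern "C"{
__global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)
{
    unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < N)
    {
        y[tid] = x1[tid] + x2[tid];
    }
}
}
'''

# Compile the CUDA source and retrieve the kernel; XLACustomOp now detects the
# CuPy RawModule path via the kernel object's `.kernel` attribute.
mod = cp.RawModule(code=source_code)
kernel = mod.get_function('kernel')

prim = bm.XLACustomOp(gpu_kernel=kernel)

N = 10
x1 = bm.ones((N, N))
x2 = bm.ones((N, N))

# grid and block must be tuples, and the output buffers (declared via `outs`)
# come last in the kernel's parameter list.
y = prim(x1, x2, N ** 2,
         grid=(N,), block=(N,),
         outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]
assert bm.allclose(y, x1 + x2)
```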