Commit: Update docs

Routhleck committed Mar 19, 2024
1 parent c8be3ee commit 0cca5df
Showing 6 changed files with 227 additions and 79 deletions.
34 changes: 5 additions & 29 deletions brainpy/_src/math/op_register/base.py
@@ -46,34 +46,10 @@ def dtype(self) -> np.dtype:
class XLACustomOp(BrainPyObject):
"""Creating a XLA custom call operator.
>>> import numba as nb
>>> import taichi as ti
>>> import numpy as np
>>> import jax
>>>
>>> @nb.njit
>>> def numba_cpu_fun(a, b, out_a, out_b):
>>> out_a[:] = a
>>> out_b[:] = b
>>>
>>> @ti.kernel
>>> def taichi_gpu_fun(a, b, out_a, out_b):
>>> for i in range(a.size):
>>> out_a[i] = a[i]
>>> for i in range(b.size):
>>> out_b[i] = b[i]
>>>
>>> # option 1
>>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun)
>>> a2, b2 = prim(np.random.random(1000), np.random.random(1000),
>>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
>>> jax.ShapeDtypeStruct(1000, dtype=np.float32)])
>>>
>>> # option 2
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
>>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
>>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))
For more information, please refer to the following tutorials:
Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html
Args:
cpu_kernel: Callable. The function that defines the computation on the CPU backend.
@@ -130,7 +106,7 @@ def __init__(
gpu_checked = False
if gpu_kernel is None:
gpu_checked = True
elif isinstance(gpu_kernel, str): # cupy RawModule
elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule
register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel
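For context, the change above means a compiled CuPy kernel object (detected via hasattr(gpu_kernel, 'kernel')) is now passed to XLACustomOp instead of a raw CUDA source string. A minimal sketch of that usage, assuming the API exercised in the updated tests and tutorial in this commit:

import jax
import cupy as cp
import brainpy.math as bm

# CUDA source defining a kernel named 'kernel'; the outputs stay at the end of the signature
source_code = r'''
extern "C" __global__
void kernel(const float* x1, const float* x2, unsigned int N, float* y)
{
    unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid < N)
    {
        y[tid] = x1[tid] + x2[tid];
    }
}
'''

# compile the module and fetch the RawKernel object explicitly ...
mod = cp.RawModule(code=source_code)
kernel = mod.get_function('kernel')

# ... then hand the RawKernel (not the source string) to XLACustomOp
prim = bm.XLACustomOp(gpu_kernel=kernel)

N = 10
x1 = bm.ones((N, N))
x2 = bm.ones((N, N))
y = prim(x1, x2, N ** 2, grid=(N,), block=(N,),
         outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]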
18 changes: 2 additions & 16 deletions brainpy/_src/math/op_register/cupy_based.py
@@ -80,16 +80,9 @@ def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
if grid is None or block is None:
raise ValueError('The grid and block should be specified for the cupy kernel.')

# compile
mod = cp.RawModule(code=kernel)
try:
kernel_func = mod.get_function('kernel')
except AttributeError:
raise ValueError('The \'kernel\' function is not found in the module.')

# preprocess
import_brainpylib_gpu_ops()
opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])

# create custom call
return xla_client.ops.CustomCallWithLayout(
@@ -116,16 +109,9 @@ def _cupy_raw_module_mlir_gpu_translation_rule(kernel, c, *ins, **kwargs):
if grid is None or block is None:
raise ValueError('The grid and block should be specified for the cupy kernel.')

# compile
mod = cp.RawModule(code=kernel)
try:
kernel_func = mod.get_function('kernel')
except AttributeError:
raise ValueError('The \'kernel\' function is not found in the module.')

# preprocess
import_brainpylib_gpu_ops()
opaque = _preprocess_kernel_call_gpu(grid, block, kernel_func.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])

input_layouts = [_shape_to_layout(a.shape) for a in c.avals_in]
result_types = [mlir.aval_to_ir_type(out) for out in c.avals_out]
42 changes: 30 additions & 12 deletions brainpy/_src/math/op_register/tests/test_cupy_based.py
@@ -1,20 +1,29 @@
import jax
import jax.numpy as jnp
import pytest

import brainpy.math as bm
from brainpy._src.dependency_check import import_cupy, import_cupy_jit
from brainpy._src.dependency_check import import_cupy, import_cupy_jit, import_taichi

cp = import_cupy(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)
if cp is None:
pytest.skip('no cupy', allow_module_level=True)
bm.set_platform('gpu')
ti = import_taichi(error_if_not_found=False)
if cp is None or ti is None:
pytest.skip('no cupy or taichi', allow_module_level=True)
bm.set_platform('cpu')


def test_cupy_based():
bm.op_register.clear_taichi_aot_caches()
# Raw Module

@ti.kernel
def simpleAdd(x1: ti.types.ndarray(ndim=2),
x2: ti.types.ndarray(ndim=2),
n: ti.types.ndarray(ndim=0),
y: ti.types.ndarray(ndim=2)):
for i, j in y:
y[i, j] = x1[i, j] + x2[i, j]

source_code = r'''
extern "C"{
@@ -31,16 +40,25 @@ def test_cupy_based():
N = 10
x1 = bm.ones((N, N))
x2 = bm.ones((N, N))
prim1 = bm.XLACustomOp(gpu_kernel=source_code)

# n = jnp.asarray([N**2,], dtype=jnp.int32)
mod = cp.RawModule(code=source_code)
kernel = mod.get_function('kernel')

prim1 = bm.XLACustomOp(cpu_kernel=simpleAdd, gpu_kernel=kernel)

y = prim1(x1, x2, N ** 2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=jnp.float32)])[0]
y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]

print(y)
assert jnp.allclose(y, x1 + x2)
assert bm.allclose(y, x1 + x2)

# JIT Kernel
@ti.kernel
def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1),
size: ti.types.ndarray(ndim=1),
y: ti.types.ndarray(ndim=1)):
for i in y:
y[i] = x[i]

@cp_jit.rawkernel()
def elementwise_copy(x, size, y):
tid = cp_jit.blockIdx.x * cp_jit.blockDim.x + cp_jit.threadIdx.x
@@ -51,11 +69,11 @@ def elementwise_copy(x, size, y):
size = 100
x = bm.ones((size,))

prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)
prim2 = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy)

y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=jnp.float32)])[0]
y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]

print(y)
assert jnp.allclose(y, x)
assert bm.allclose(y, x)

# test_cupy_based()
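The updated test above pairs a Taichi cpu_kernel with a CuPy gpu_kernel so the same operator can run on both backends. A minimal sketch of that pattern, reusing the kernels shown in test_cupy_based.py; treat it as illustrative rather than canonical:

import jax
import taichi as ti
import brainpy.math as bm
from cupyx import jit

# Taichi kernel: the CPU implementation of the elementwise copy
@ti.kernel
def elementwise_copy_taichi(x: ti.types.ndarray(ndim=1),
                            size: ti.types.ndarray(ndim=1),
                            y: ti.types.ndarray(ndim=1)):
    for i in y:
        y[i] = x[i]

# CuPy JIT kernel: the GPU implementation (grid-stride loop)
@jit.rawkernel()
def elementwise_copy(x, size, y):
    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
    ntid = jit.gridDim.x * jit.blockDim.x
    for i in range(tid, size, ntid):
        y[i] = x[i]

# one operator, both backends registered
prim = bm.XLACustomOp(cpu_kernel=elementwise_copy_taichi, gpu_kernel=elementwise_copy)

size = 100
x = bm.ones((size,))
y = prim(x, size, grid=(10,), block=(10,),
         outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]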
19 changes: 7 additions & 12 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
@@ -11,10 +11,9 @@

bm.set_platform('cpu')


@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
return weight[0]
def get_weight(weight: ti.types.ndarray(ndim=0)) -> ti.f32:
return weight[None]


@ti.func
@@ -25,7 +24,7 @@ def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f
@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
vector: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=0),
out: ti.types.ndarray(ndim=1)):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
@@ -35,11 +34,10 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)


@ti.kernel
def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
vector: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=0),
out: ti.types.ndarray(ndim=1)):
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
@@ -48,21 +46,18 @@ def event_ell_gpu(indices: ti.types.ndarray(ndim=2),
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)


prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu, gpu_kernel=event_ell_gpu)


def test_taichi_op_register():
s = 1000
indices = bm.random.randint(0, s, (s, 100))
indices = bm.random.randint(0, s, (s, 1000))
vector = bm.random.rand(s) < 0.1
weight = bm.array([1.0])

out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
out = prim(indices, vector, 1.0, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

print(out)
bm.clear_buffer_memory()

# test_taichi_op_register()
174 changes: 174 additions & 0 deletions docs/tutorial_advanced/operator_custom_with_cupy.ipynb
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# CPU and GPU Operator Customization with CuPy\n",
"\n",
"[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)\n",
"[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs/tutorial_advanced/operator_custom_with_cupy.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This functionality is only available for ``brainpylib>=0.3.1``. "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## English Version"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Although we can now use the flexible taichi custom operator approach, taichi on cuda does not have more fine-grained control or optimization for some scenarios. So for such scenarios, we can use cupy's \n",
"- `RawModule`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n",
"- `jit.rawkernel`(https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n",
"\n",
"to compile and run CUDA native code directly as strings or cupy JIT function in real time for finer grained control.\n",
"\n",
"Start by importing the relevant Python package."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import brainpy.math as bm\n",
"\n",
"import jax\n",
"import cupy as cp\n",
"from cupyx import jit\n",
"\n",
"bm.set_platform('gpu')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CuPy RawModule\n",
"\n",
"For dealing a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be more handy. It can be initialized either by a CUDA source code. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.\n",
"\n",
"Be aware that the order of parameters in the kernel function you want to call should **keep outputs at the end of the parameter list**."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"source_code = r'''\n",
" extern \"C\"{\n",
"\n",
" __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n",
" {\n",
" unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n",
" if (tid < N)\n",
" {\n",
" y[tid] = x1[tid] + x2[tid];\n",
" }\n",
" }\n",
" }\n",
"'''\n",
"mod = cp.RawModule(code=source_code)\n",
"kernel = mod.get_function('kernel')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# prepare inputs\n",
"N = 10\n",
"x1 = bm.ones((N, N))\n",
"x2 = bm.ones((N, N))\n",
"\n",
"# register the kernel as a custom op\n",
"prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n",
"\n",
"# call the custom op\n",
"y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CuPy JIT RawKernel\n",
"The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.\n",
"\n",
"In this section, a Python function wrapped with the decorator is called a target function.\n",
"\n",
"Here is a short example for how to write a cupyx.jit.rawkernel to copy the values from x to y using a grid-stride loop:\n",
"\n",
"Launching a CUDA kernel on a GPU with pre-determined grid/block sizes requires basic understanding in the CUDA Programming Model. And the compilation will be deferred until the first function call. CuPy’s JIT compiler infers the types of arguments at the call time, and will cache the compiled kernels for speeding up any subsequent calls."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@jit.rawkernel()\n",
"def elementwise_copy(x, size, y):\n",
" tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n",
" ntid = jit.gridDim.x * jit.blockDim.x\n",
" for i in range(tid, size, ntid):\n",
" y[i] = x[i]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"After define the `jit.rawkernel`. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# prepare inputs\n",
"size = 100\n",
"x = bm.ones((size,))\n",
"\n",
"# register the kernel as a custom op\n",
"prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n",
"\n",
"# call the custom op\n",
"y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}