small update
chaoming0625 committed Mar 20, 2024
1 parent 0cca5df, commit 1849423
Showing 1 changed file with 41 additions and 63 deletions.
brainpy/_src/math/op_register/cupy_based.py: 104 changes (41 additions, 63 deletions)
@@ -1,5 +1,5 @@
 from functools import partial, reduce
-from typing import List
+from typing import List, Tuple
 
 import jax
 import numpy as np
@@ -37,8 +37,8 @@
 
 
 def _preprocess_kernel_call_gpu(
-    grid: int,
-    block: int,
+    grid: Tuple[int],
+    block: Tuple[int],
     func_ptr: int,
     shared_mem: int,
     *ins,
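For context on the `Tuple[int]` hints above: CUDA launch configurations can have up to three dimensions, so `grid` and `block` are naturally tuples rather than single integers. A minimal illustration (the sizes are assumptions, not taken from this commit):

```python
# 1-D launch: one-element tuples
grid_1d, block_1d = (4096,), (256,)

# 2-D launch, e.g. for an image-like kernel
grid_2d, block_2d = (64, 64), (16, 16)
```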
@@ -82,6 +82,8 @@ def _cupy_raw_module_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
 
   # preprocess
   import_brainpylib_gpu_ops()
+  # THE KEY:
+  # - using the kernel pointer at "kernel.kernel.ptr"
   opaque = _preprocess_kernel_call_gpu(grid, block, kernel.kernel.ptr, shared_mem, *ins, outs=kwargs['outs'])
 
   # create custom call
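The new comment points at `kernel.kernel.ptr`, the device-function pointer of a CuPy `RawModule` kernel. A minimal sketch of where such an object comes from, assuming the standard CuPy `RawModule` API (the kernel source, names, and sizes are illustrative, not part of this commit):

```python
import cupy as cp

_src = r'''
extern "C" __global__
void add_vectors(const float* x, const float* y, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = x[i] + y[i];
}
'''

mod = cp.RawModule(code=_src)
kernel = mod.get_function('add_vectors')   # a cupy.RawKernel

n = 1 << 20
x = cp.random.rand(n, dtype=cp.float32)
y = cp.random.rand(n, dtype=cp.float32)
out = cp.empty_like(x)

grid, block = ((n + 255) // 256,), (256,)  # launch config as tuples
kernel(grid, block, (x, y, out, cp.int32(n)))

# kernel.kernel is the underlying compiled function object; its .ptr attribute
# is the pointer that the translation rule above hands to
# _preprocess_kernel_call_gpu.
```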
@@ -136,28 +138,28 @@ def register_cupy_raw_module_mlir_gpu_translation_rule(primitive, gpu_kernel):
   mlir.register_lowering(primitive, rule, platform='gpu')
 
 
-def get_jit_kernel_xla(kernel, c, *ins, outs):
-  # check if compiled
-  in_types = []
-  for x in ins:
-    x = c.get_shape(x)
-    if len(x.dimensions()) != 0:
-      t = cp_jit._cuda_types.CArray(dtype=x.element_type(), ndim=len(x.dimensions()), is_c_contiguous=True,
-                                    index_32_bits=True)
-    else:
-      t = cp_jit._cuda_types.Scalar(dtype=x.element_type())
-    in_types.append(t)
-  for x in outs:
-    if x.ndim != 0:
-      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
-    else:
-      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
-    in_types.append(t)
-  in_types = tuple(in_types)
+def _to_cupy_array_or_scalar(dtype, ndim):
+  # THE KEY
+  # - using the cupy jit compiler to get the type
+  if ndim != 0:
+    t = cp_jit._cuda_types.CArray(dtype=dtype,
+                                  ndim=ndim,
+                                  is_c_contiguous=True,
+                                  index_32_bits=True)
+  else:
+    t = cp_jit._cuda_types.Scalar(dtype=dtype)
+  return t
+
+
+def _compile_kernel_xla(kernel, in_types):
+  # THE KEY
+  # - get the kernel function from the cache
   device_id = cp.cuda.get_device_id()
   kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
 
   if kern is None:
+    # THE KEY:
+    # - compile the kernel function
     result = kernel._cached_codes.get(in_types)
     if result is None:
       result = cp_jit._compile.transpile(
@@ -187,53 +189,29 @@ def get_jit_kernel_xla(kernel, c, *ins, outs):
   return kern
 
 
+def get_jit_kernel_xla(kernel, c, *ins, outs):
+  # get the input types
+  in_types = []
+  for x in ins:
+    x = c.get_shape(x)
+    in_types.append(_to_cupy_array_or_scalar(x.element_type(), len(x.dimensions())))
+  for x in outs:
+    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
+  in_types = tuple(in_types)
+  # compile the kernel
+  return _compile_kernel_xla(kernel, in_types)
+
+
 def get_jit_kernel_mlir(kernel, c):
-  # check if compiled
+  # get the input types
   in_types = []
   for x in c.avals_in:
-    if x.ndim != 0:
-      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
-    else:
-      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
-    in_types.append(t)
+    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
   for x in c.avals_out:
-    if x.ndim != 0:
-      t = cp_jit._cuda_types.CArray(dtype=x.dtype, ndim=x.ndim, is_c_contiguous=True, index_32_bits=True)
-    else:
-      t = cp_jit._cuda_types.Scalar(dtype=x.dtype)
-    in_types.append(t)
+    in_types.append(_to_cupy_array_or_scalar(x.dtype, x.ndim))
   in_types = tuple(in_types)
-  device_id = cp.cuda.get_device_id()
-  kern, enable_cg = kernel._cache.get((in_types, device_id), (None, None))
-
-  if kern is None:
-    result = kernel._cached_codes.get(in_types)
-    if result is None:
-      result = cp_jit._compile.transpile(
-        kernel._func,
-        ['extern "C"', '__global__'],
-        'cuda',
-        in_types,
-        cp_jit._cuda_types.void,
-      )
-      kernel._cached_codes[in_types] = result
-    fname = result.func_name
-    enable_cg = result.enable_cooperative_groups
-    options = result.options
-    backend = result.backend
-    if backend == 'nvcc':
-      options += ('-DCUPY_JIT_NVCC',)
-    jitify = result.jitify
-    module = cp._core.core.compile_with_cache(
-      source=result.code,
-      options=options,
-      backend=backend,
-      jitify=jitify,
-    )
-    kern = module.get_function(fname)
-    kernel._cache[(in_types, device_id)] = (kern, enable_cg)
-
-  return kern
+  # compile the kernel
+  return _compile_kernel_xla(kernel, in_types)
 
 
 def _cupy_jit_kernel_xla_gpu_translation_rule(kernel, c, *ins, **kwargs):
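The `_cache` and `_cached_codes` dictionaries consulted by the new `_compile_kernel_xla` helper are attributes of kernels created with CuPy's experimental JIT (`cupyx.jit.rawkernel`). A minimal sketch of that interface, loosely following the CuPy documentation (the kernel body and sizes are illustrative assumptions):

```python
import cupy as cp
from cupyx import jit

@jit.rawkernel()
def elementwise_copy(x, y, size):
  tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x
  ntid = jit.gridDim.x * jit.blockDim.x
  for i in range(tid, size, ntid):
    y[i] = x[i]

size = cp.uint32(2 ** 20)
x = cp.random.normal(size=(int(size),), dtype=cp.float32)
y = cp.empty((int(size),), dtype=cp.float32)

# The first call transpiles and compiles the kernel; the compiled function is
# cached per (argument types, device id), which is the cache the refactored
# helpers above consult before recompiling.
elementwise_copy((128,), (1024,), (x, y, size))
```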
