Skip to content

Commit

Permalink
[math] Fix taichi custom operator on gpu backend
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Mar 22, 2024
1 parent 3866203 commit 6222464
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ def _preprocess_kernel_call_cpu(

def _preprocess_kernel_call_gpu(
source_md5_encode: str,
ins: dict,
outs: dict,
ins: Sequence,
outs: Sequence,
) -> bytes:
if len(ins) + len(outs) > 8:
raise ValueError('The number of ins and outs must be less than 8!')
# if len(ins) + len(outs) > 8:
# raise ValueError('The number of ins and outs must be less than 8!')
kernel_path = os.path.join(kernels_aot_path, source_md5_encode)

# other args
Expand All @@ -331,18 +331,18 @@ def _preprocess_kernel_call_gpu(
in_out_elem_count_list = [0] * param_total_num
in_out_shape_list = [0] * param_total_num * 8

for i, value in enumerate(ins.values()):
in_out_type_list[i] = type_number_map[value[0]]
in_out_dim_count_list[i] = len(value[1])
in_out_elem_count_list[i] = reduce(lambda x, y: x * y, value[1])
for j, dim in enumerate(value[1]):
for i, value in enumerate(ins):
in_out_type_list[i] = type_number_map[value.dtype]
in_out_dim_count_list[i] = value.ndim
in_out_elem_count_list[i] = value.size
for j, dim in enumerate(value.shape):
in_out_shape_list[i * 8 + j] = dim

for i, value in enumerate(outs.values()):
in_out_type_list[i + len(ins)] = type_number_map[value[0]]
in_out_dim_count_list[i + len(ins)] = len(value[1])
in_out_elem_count_list[i + len(ins)] = reduce(lambda x, y: x * y, value[1])
for j, dim in enumerate(value[1]):
for i, value in enumerate(outs):
in_out_type_list[i + len(ins)] = type_number_map[value.dtype]
in_out_dim_count_list[i + len(ins)] = value.ndim
in_out_elem_count_list[i + len(ins)] = value.size
for j, dim in enumerate(value.shape):
in_out_shape_list[(i + len(ins)) * 8 + j] = dim

# covert to string
Expand Down Expand Up @@ -407,7 +407,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs):
# returns
if platform in ['gpu', 'cuda']:
import_brainpylib_gpu_ops()
opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict)
opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs)
return opaque
elif platform == 'cpu':
import_brainpylib_cpu_ops()
Expand Down

0 comments on commit 6222464

Please sign in to comment.