diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 2a8cb3b6..858f338b 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -316,11 +316,11 @@ def _preprocess_kernel_call_cpu( def _preprocess_kernel_call_gpu( source_md5_encode: str, - ins: dict, - outs: dict, + ins: Sequence, + outs: Sequence, ) -> bytes: - if len(ins) + len(outs) > 8: - raise ValueError('The number of ins and outs must be less than 8!') + # if len(ins) + len(outs) > 8: + # raise ValueError('The number of ins and outs must be less than 8!') kernel_path = os.path.join(kernels_aot_path, source_md5_encode) # other args @@ -331,18 +331,18 @@ def _preprocess_kernel_call_gpu( in_out_elem_count_list = [0] * param_total_num in_out_shape_list = [0] * param_total_num * 8 - for i, value in enumerate(ins.values()): - in_out_type_list[i] = type_number_map[value[0]] - in_out_dim_count_list[i] = len(value[1]) - in_out_elem_count_list[i] = reduce(lambda x, y: x * y, value[1]) - for j, dim in enumerate(value[1]): + for i, value in enumerate(ins): + in_out_type_list[i] = type_number_map[value.dtype] + in_out_dim_count_list[i] = value.ndim + in_out_elem_count_list[i] = value.size + for j, dim in enumerate(value.shape): in_out_shape_list[i * 8 + j] = dim - for i, value in enumerate(outs.values()): - in_out_type_list[i + len(ins)] = type_number_map[value[0]] - in_out_dim_count_list[i + len(ins)] = len(value[1]) - in_out_elem_count_list[i + len(ins)] = reduce(lambda x, y: x * y, value[1]) - for j, dim in enumerate(value[1]): + for i, value in enumerate(outs): + in_out_type_list[i + len(ins)] = type_number_map[value.dtype] + in_out_dim_count_list[i + len(ins)] = value.ndim + in_out_elem_count_list[i + len(ins)] = value.size + for j, dim in enumerate(value.shape): in_out_shape_list[(i + len(ins)) * 8 + j] = dim # covert to string @@ -407,7 +407,7 @@ def _compile_kernel(abs_ins, kernel, platform: str, **kwargs): # returns if platform in ['gpu', 'cuda']: import_brainpylib_gpu_ops() - opaque = _preprocess_kernel_call_gpu(source_md5_encode, ins_dict, outs_dict) + opaque = _preprocess_kernel_call_gpu(source_md5_encode, abs_ins, abs_outs) return opaque elif platform == 'cpu': import_brainpylib_cpu_ops()