diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index 27998038..a22aac75 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -1157,4 +1157,4 @@ def _define_event_mv_prob_normal_prim(cpu_kernel, gpu_kernel): _event_mv_prob_normal_p = _define_event_mv_prob_normal_prim( cpu_kernel=_event_mv_prob_normal_cpu, gpu_kernel=_event_mv_prob_normal_gpu - ) + ) \ No newline at end of file diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 595460ea..2a8cb3b6 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -324,11 +324,12 @@ def _preprocess_kernel_call_gpu( kernel_path = os.path.join(kernels_aot_path, source_md5_encode) # other args + param_total_num = len(ins) + len(outs) in_out_num = [len(ins), len(outs)] - in_out_type_list = [0] * 8 - in_out_dim_count_list = [0] * 8 - in_out_elem_count_list = [0] * 8 - in_out_shape_list = [0] * 64 + in_out_type_list = [0] * param_total_num + in_out_dim_count_list = [0] * param_total_num + in_out_elem_count_list = [0] * param_total_num + in_out_shape_list = [0] * param_total_num * 8 for i, value in enumerate(ins.values()): in_out_type_list[i] = type_number_map[value[0]]