diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py index e1bed7de..091468c9 100644 --- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py +++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py @@ -31,14 +31,14 @@ def eval_shape2(a, b): d = ShapedArray(b.shape, dtype=b.dtype) return c, d -@numba.njit(parallel=True) + def con_compute2(outs, ins): - c = outs[0] # take out all the outputs - d = outs[1] - a = ins[0] # take out all the inputs - b = ins[1] - # c, d = outs - # a, b = ins + # c = outs[0] # take out all the outputs + # d = outs[1] + # a = ins[0] # take out all the inputs + # b = ins[1] + c, d = outs + a, b = ins c[:] = a + 1 d[:] = b * 2 diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index e4f8dd20..0b840db0 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -149,10 +149,8 @@ " return c, d\n", "\n", "def con_compute2(outs, ins):\n", - " c = outs[0] # take out all the outputs\n", - " d = outs[1]\n", - " a = ins[0] # take out all the inputs\n", - " b = ins[1]\n", + " c, d = outs # take out all the outputs\n", + " a, b = ins # take out all the inputs\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -193,8 +191,7 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # Take out all the outputs\n", - " a = ins[0] # Take out all inputs\n", - " b = ins[1]\n", + " a, b = ins # Take out all inputs\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -437,10 +434,8 @@ " return c, d # 返回多个抽象数组信息\n", "\n", "def con_compute2(outs, ins):\n", - " c = outs[0] # 取出所有的输出\n", - " d = outs[1]\n", - " a = ins[0] # 取出所有的输入\n", - " b = ins[1]\n", + " c, d = outs # 取出所有的输出\n", + " a, b = ins # 取出所有的输入\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -481,8 +476,7 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # 取出所有的输出\n", - " a = ins[0] # 取出所有的输入\n", - " b = ins[1]\n", + " a, b = ins # 取出所有的输入\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",