diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb index e1121f5b..e4f8dd20 100644 --- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb +++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb @@ -149,8 +149,10 @@ " return c, d\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " c = outs[0] # take out all the outputs\n", + " d = outs[1]\n", + " a = ins[0] # take out all the inputs\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -191,7 +193,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # Take out all the outputs\n", - " a, b = ins # Take out all inputs\n", + " a = ins[0] # Take out all inputs\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n", @@ -434,8 +437,10 @@ " return c, d # 返回多个抽象数组信息\n", "\n", "def con_compute2(outs, ins):\n", - " c, d = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " c = outs[0] # 取出所有的输出\n", + " d = outs[1]\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = a + 1\n", " d[:] = a * 2\n", "\n", @@ -476,7 +481,8 @@ "\n", "def con_compute3(outs, ins):\n", " c = outs # 取出所有的输出\n", - " a, b = ins # 取出所有的输入\n", + " a = ins[0] # 取出所有的输入\n", + " b = ins[1]\n", " c[:] = 2.\n", "\n", "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",