Skip to content

Commit

Permalink
Update operator_custom_with_numba.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 14, 2024
1 parent 78c239c commit 48023d2
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions docs/tutorial_advanced/operator_custom_with_numba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,10 @@
" return c, d\n",
"\n",
"def con_compute2(outs, ins):\n",
" c, d = outs # 取出所有的输出\n",
" a, b = ins # 取出所有的输入\n",
" c = outs[0] # take out all the outputs\n",
" d = outs[1]\n",
" a = ins[0] # take out all the inputs\n",
" b = ins[1]\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
Expand Down Expand Up @@ -191,7 +193,8 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # Take out all the outputs\n",
" a, b = ins # Take out all inputs\n",
" a = ins[0] # Take out all inputs\n",
" b = ins[1]\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
Expand Down Expand Up @@ -434,8 +437,10 @@
" return c, d # 返回多个抽象数组信息\n",
"\n",
"def con_compute2(outs, ins):\n",
" c, d = outs # 取出所有的输出\n",
" a, b = ins # 取出所有的输入\n",
" c = outs[0] # 取出所有的输出\n",
" d = outs[1]\n",
" a = ins[0] # 取出所有的输入\n",
" b = ins[1]\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
Expand Down Expand Up @@ -476,7 +481,8 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # 取出所有的输出\n",
" a, b = ins # 取出所有的输入\n",
" a = ins[0] # 取出所有的输入\n",
" b = ins[1]\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
Expand Down

0 comments on commit 48023d2

Please sign in to comment.