Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed May 15, 2024
1 parent c70b3dd commit a9a9845
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def eval_shape2(a, b):
d = ShapedArray(b.shape, dtype=b.dtype)
return c, d

@numba.njit(parallel=True)

def con_compute2(outs, ins):
c = outs[0] # take out all the outputs
d = outs[1]
a = ins[0] # take out all the inputs
b = ins[1]
# c, d = outs
# a, b = ins
# c = outs[0] # take out all the outputs
# d = outs[1]
# a = ins[0] # take out all the inputs
# b = ins[1]
c, d = outs
a, b = ins
c[:] = a + 1
d[:] = b * 2

Expand Down
18 changes: 6 additions & 12 deletions docs/tutorial_advanced/operator_custom_with_numba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,8 @@
" return c, d\n",
"\n",
"def con_compute2(outs, ins):\n",
" c = outs[0] # take out all the outputs\n",
" d = outs[1]\n",
" a = ins[0] # take out all the inputs\n",
" b = ins[1]\n",
" c, d = outs # take out all the outputs\n",
" a, b = ins # take out all the inputs\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
Expand Down Expand Up @@ -193,8 +191,7 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # Take out all the outputs\n",
" a = ins[0] # Take out all inputs\n",
" b = ins[1]\n",
" a, b = ins # Take out all inputs\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
Expand Down Expand Up @@ -437,10 +434,8 @@
" return c, d # 返回多个抽象数组信息\n",
"\n",
"def con_compute2(outs, ins):\n",
" c = outs[0] # 取出所有的输出\n",
" d = outs[1]\n",
" a = ins[0] # 取出所有的输入\n",
" b = ins[1]\n",
" c, d = outs # 取出所有的输出\n",
" a, b = ins # 取出所有的输入\n",
" c[:] = a + 1\n",
" d[:] = a * 2\n",
"\n",
Expand Down Expand Up @@ -481,8 +476,7 @@
"\n",
"def con_compute3(outs, ins):\n",
" c = outs # 取出所有的输出\n",
" a = ins[0] # 取出所有的输入\n",
" b = ins[1]\n",
" a, b = ins # 取出所有的输入\n",
" c[:] = 2.\n",
"\n",
"op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
Expand Down

0 comments on commit a9a9845

Please sign in to comment.