diff --git a/brainpy/_src/math/op_register/numba_approach/__init__.py b/brainpy/_src/math/op_register/numba_approach/__init__.py
index 4d9b284f..35c9beef 100644
--- a/brainpy/_src/math/op_register/numba_approach/__init__.py
+++ b/brainpy/_src/math/op_register/numba_approach/__init__.py
@@ -276,7 +276,7 @@ def abs_eval_rule(*input_shapes, **info):
 
   cpu_translation_rule = partial(_numba_mlir_cpu_translation_rule,
                                  cpu_func,
-                                 True)
+                                 False)
   mlir.register_lowering(prim, cpu_translation_rule, platform='cpu')
 
 
diff --git a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
index 091468c9..21099cb6 100644
--- a/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
+++ b/brainpy/_src/math/op_register/numba_approach/tests/test_numba_approach.py
@@ -31,14 +31,13 @@ def eval_shape2(a, b):
   d = ShapedArray(b.shape, dtype=b.dtype)
   return c, d
-
 def con_compute2(outs, ins):
-  # c = outs[0]  # take out all the outputs
-  # d = outs[1]
-  # a = ins[0]   # take out all the inputs
-  # b = ins[1]
-  c, d = outs
-  a, b = ins
+  c = outs[0]  # take out all the outputs
+  d = outs[1]
+  a = ins[0]   # take out all the inputs
+  b = ins[1]
+  # c, d = outs
+  # a, b = ins
   c[:] = a + 1
   d[:] = b * 2
 
 
diff --git a/docs/tutorial_advanced/operator_custom_with_numba.ipynb b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
index 0b840db0..7f00cd56 100644
--- a/docs/tutorial_advanced/operator_custom_with_numba.ipynb
+++ b/docs/tutorial_advanced/operator_custom_with_numba.ipynb
@@ -65,8 +65,6 @@
    "source": [
     "### ``brainpy.math.CustomOpByNumba``\n",
     "\n",
-    "``brainpy.math.CustomOpByNumba`` is also called ``brainpy.math.XLACustomOp``.\n",
-    "\n",
     "BrainPy provides ``brainpy.math.CustomOpByNumba`` for customizing the operator on the CPU device. Two parameters are required to provide in ``CustomOpByNumba``:\n",
     "\n",
     "- ``eval_shape``: evaluates the *shape* and *datatype* of the output argument based on the *shape* and *datatype* of the input argument.\n",
@@ -137,7 +135,7 @@
     "collapsed": false
    },
    "source": [
-    "### Return multiple values ``multiple_returns=True``\n",
+    "#### Return multiple values ``multiple_returns=True``\n",
     "\n",
     "If the result of our computation needs to return multiple arrays, then we need to use ``multiple_returns=True`` in our use of registering the operator. In this case, ``outs`` will be a list containing multiple arrays, not an array.\n",
     "\n",
@@ -149,8 +147,10 @@
     "  return c, d\n",
     "\n",
     "def con_compute2(outs, ins):\n",
-    "  c, d = outs  # take out all the outputs\n",
-    "  a, b = ins   # take out all the inputs\n",
+    "  c = outs[0]  # take out all the outputs\n",
+    "  d = outs[1]\n",
+    "  a = ins[0]   # take out all the inputs\n",
+    "  b = ins[1]\n",
     "  c[:] = a + 1\n",
     "  d[:] = a * 2\n",
     "\n",
@@ -170,7 +170,7 @@
     "collapsed": false
    },
    "source": [
-    "### Non-Tracer parameters\n",
+    "#### Non-Tracer parameters\n",
     "\n",
     "In the ``eval_shape`` function, all arguments are abstract information (containing only the shape and type) if they are arguments that can be traced by ``jax.jit``. However, if we infer the output data type requires additional information beyond the input parameter information, then we need to define non-Tracer parameters.\n",
     "\n",
@@ -191,7 +191,8 @@
     "\n",
     "def con_compute3(outs, ins):\n",
     "  c = outs  # Take out all the outputs\n",
-    "  a, b = ins  # Take out all inputs\n",
+    "  a = ins[0]  # Take out all inputs\n",
+    "  b = ins[1]\n",
     "  c[:] = 2.\n",
     "\n",
     "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -221,7 +222,7 @@
     "collapsed": false
    },
    "source": [
-    "### Example: A sparse operator\n",
+    "#### Example: A sparse operator\n",
     "\n",
     "To illustrate the effectiveness of this approach, we define in this an event-driven sparse computation operator."
    ]
@@ -297,6 +298,50 @@
     "f(1.)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### brainpy.math.XLACustomOp\n",
+    "\n",
+    "`brainpy.math.XLACustomOp` is a new method for customizing operators. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. To use it with Numba, you only need to define a kernel with the `@numba.jit` or `@numba.njit` decorator and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
+    "\n",
+    "Detailed steps are as follows:\n",
+    "\n",
+    "#### Define the kernel\n",
+    "\n",
+    "```python\n",
+    "@numba.njit(fastmath=True)\n",
+    "def numba_event_csrmv(weight, indices, vector, outs):\n",
+    "  outs.fill(0)\n",
+    "  weight = weight[()]  # 0d\n",
+    "  for row_i in range(vector.shape[0]):\n",
+    "    if vector[row_i]:\n",
+    "      for j in indices[row_i]:\n",
+    "        outs[j] += weight\n",
+    "```\n",
+    "\n",
+    "In the parameter list, the last few parameters must be the output parameters so that Numba can compile the kernel correctly. The operator `numba_event_csrmv` receives four parameters: `weight`, `indices`, `vector`, and `outs`. The first three are inputs and the last one is the output. The output is a 1D array, while the inputs are a 0D array (`weight`), a 2D array (`indices`), and a 1D array (`vector`), respectively."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Registering and Using Custom Operators\n",
+    "After defining the kernel, you can register it as a custom operator and use it wherever it is needed. When registering, you can specify a `cpu_kernel` and a `gpu_kernel`, so the operator can run on different devices. When calling the operator, specify the `outs` argument and use `jax.ShapeDtypeStruct` to define the shape and data type of each output.\n",
+    "\n",
+    "Note: keep the order of the arguments at the call site consistent with the order of the parameters declared in the kernel.\n",
+    "\n",
+    "```python\n",
+    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+    "indices = bm.random.randint(0, size, (size, 80))\n",
+    "vector = bm.random.rand(size) < 0.1\n",
+    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([size], dtype=bm.float32)])\n",
+    "print(out)\n",
+    "```"
+   ]
+  },
   {
    "cell_type": "markdown",
    "metadata": {
@@ -423,7 +468,7 @@
     "collapsed": false
    },
    "source": [
-    "### 返回多个值 ``multiple_returns=True``\n",
+    "#### 返回多个值 ``multiple_returns=True``\n",
     "\n",
     "如果我们的计算结果需要返回多个数组,那么,我们在注册算子的使用需要使用``multiple_returns=True``。此时,``outs``将会是一个包含多个数组的列表,而不是一个数组。\n",
     "\n",
@@ -434,8 +479,10 @@
     "  return c, d  # 返回多个抽象数组信息\n",
     "\n",
     "def con_compute2(outs, ins):\n",
-    "  c, d = outs  # 取出所有的输出\n",
-    "  a, b = ins   # 取出所有的输入\n",
+    "  c = outs[0]  # 取出所有的输出\n",
+    "  d = outs[1]\n",
+    "  a = ins[0]   # 取出所有的输入\n",
+    "  b = ins[1]\n",
     "  c[:] = a + 1\n",
     "  d[:] = a * 2\n",
     "\n",
@@ -455,7 +502,7 @@
     "collapsed": false
    },
    "source": [
-    "### 非Tracer参数\n",
+    "#### 非Tracer参数\n",
     "\n",
     "在``eval_shape``函数中推断数据类型时,如果所有参数都是可以被``jax.jit``追踪的参数,那么所有参数都是抽象信息(只包含形状和类型)。如果有时推断输出数据类型时还需要除输入参数信息以外的额外信息,此时我们需要定义非Tracer参数。\n",
     "\n",
@@ -476,7 +523,8 @@
     "\n",
     "def con_compute3(outs, ins):\n",
     "  c = outs  # 取出所有的输出\n",
-    "  a, b = ins  # 取出所有的输入\n",
+    "  a = ins[0]  # 取出所有的输入\n",
+    "  b = ins[1]\n",
     "  c[:] = 2.\n",
     "\n",
     "op3 = bm.CustomOpByNumba(eval_shape3, con_compute3, multiple_results=False)\n",
@@ -506,7 +554,7 @@
     "collapsed": false
    },
    "source": [
-    "### 示例:一个稀疏算子\n",
+    "#### 示例:一个稀疏算子\n",
     "\n",
     "为了说明这种方法的有效性,我们在这个定义一个事件驱动的稀疏计算算子。"
    ]
@@ -581,6 +629,50 @@
     "f = jit(lambda a: sparse_cus_op(a, sparse_A[0], sparse_A[1], vector, shape=(size, size)))\n",
     "f(1.)"
    ]
-  }
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### brainpy.math.XLACustomOp\n",
+    "\n",
+    "`brainpy.math.XLACustomOp` is a new method for customizing operators. It is similar to `brainpy.math.CustomOpByNumba`, but it is more flexible and supports more advanced features. To use it with Numba, you only need to define a kernel with the `@numba.jit` or `@numba.njit` decorator and then pass the kernel to `brainpy.math.XLACustomOp`.\n",
+    "`brainpy.math.XLACustomOp`是一种自定义算子的新方法。它类似于`brainpy.math.CustomOpByNumba`,但它更灵活并支持更高级的特性。如果您想用Numba使用这种新方法,只需要使用`@numba.jit`或`@numba.njit`装饰器定义一个kernel,然后将内核传递给`brainpy.math.XLACustomOp`。\n",
+    "\n",
+    "详细步骤如下:\n",
+    "\n",
+    "#### 定义kernel\n",
+    "在参数声明中,最后几个参数需要是输出参数,这样Numba才能正确编译。这个算子`numba_event_csrmv`接受四个参数:`weight`、`indices`、`vector`和`outs`。前三个参数是输入参数,最后一个参数是输出参数。输出参数是一个一维数组,输入参数分别是0D(`weight`)、2D(`indices`)和1D(`vector`)数组。\n",
+    "\n",
+    "```python\n",
+    "@numba.njit(fastmath=True)\n",
+    "def numba_event_csrmv(weight, indices, vector, outs):\n",
+    "  outs.fill(0)\n",
+    "  weight = weight[()]  # 0d\n",
+    "  for row_i in range(vector.shape[0]):\n",
+    "    if vector[row_i]:\n",
+    "      for j in indices[row_i]:\n",
+    "        outs[j] += weight\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 注册并使用自定义算子\n",
+    "在定义了自定义算子之后,可以将其注册到特定框架中,并在需要的地方使用它。在注册时可以指定`cpu_kernel`和`gpu_kernel`,这样算子就可以在不同的设备上运行。在调用时需要指定`outs`参数,用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。\n",
+    "\n",
+    "注意:调用时传入参数的顺序需要与kernel中声明的参数顺序保持一致。\n",
+    "\n",
+    "```python\n",
+    "prim = bm.XLACustomOp(cpu_kernel=numba_event_csrmv)\n",
+    "indices = bm.random.randint(0, size, (size, 80))\n",
+    "vector = bm.random.rand(size) < 0.1\n",
+    "out = prim(1., indices, vector, outs=[jax.ShapeDtypeStruct([size], dtype=bm.float32)])\n",
+    "print(out)\n",
+    "```"
+   ]
+  }
 ],
 "metadata": {