From 4edeeaf6950440d95ba4fd4fdcdda48625431442 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Tue, 23 Jul 2024 10:47:04 +0800 Subject: [PATCH 1/3] im2col parameter bug: use SRC_UNIT>pack to decide weight&input shape. Signed-off-by: jingbang.yjb --- .../cpu/compute/ConvInt8TiledExecutor.cpp | 26 +++++++++++-------- .../cpu/compute/ConvInt8TiledExecutor.hpp | 2 +- .../cpu/compute/ConvolutionTiledExecutor.cpp | 7 +---- .../backend/cpu/compute/GemmInt8Executor.cpp | 4 +-- .../cpu/compute/IdstConvolutionInt8.cpp | 4 +-- 5 files changed, 21 insertions(+), 22 deletions(-) diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 756f24aee..6471acb3a 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -37,21 +37,21 @@ ErrorCode ConvInt8TiledExecutor::onResize(const std::vector& inputs, co return NO_ERROR; } -void ConvInt8TiledExecutor::reorderWeight(Tensor* weight, const uint8_t* weightSrc, int SRC_UNIT, int UNIT, int ic, int oc, int kernelCount) { +void ConvInt8TiledExecutor::reorderWeight(Tensor* weight, const uint8_t* weightSrc, int SRC_UNIT, int UNIT, int ic, int oc, int kernelCount, int pack) { auto weightDst = weight->host(); memset(weightDst, 0, weight->size()); - if (SRC_UNIT > UNIT) { - auto icDivU = UP_DIV(ic, UNIT); + if (SRC_UNIT > pack) { + auto icDivU = UP_DIV(ic, pack); for (int k = 0; k < kernelCount; ++k) { const auto srcK = weightSrc + k; for (int y = 0; y < ic; ++y) { - const int yOutSide = y / UNIT; - const int yInSide = y % UNIT; + const int yOutSide = y / pack; + const int yInSide = y % pack; const int yIndex = yOutSide + k * icDivU; - const int ySubOutSide = yIndex / (SRC_UNIT / UNIT); - const int ySubInSide = yIndex % (SRC_UNIT / UNIT); + const int ySubOutSide = yIndex / (SRC_UNIT / pack); + const int ySubInSide = yIndex % (SRC_UNIT / pack); - auto dstY = weightDst + ySubOutSide * weight->stride(1) + ySubInSide * UNIT + yInSide; + auto dstY = weightDst + ySubOutSide * weight->stride(1) + ySubInSide * pack + yInSide; const auto srcY = srcK + y * kernelCount; for (int x = 0; x < oc; ++x) { const int xOutSide = x / UNIT; @@ -94,9 +94,13 @@ static bool _reorderWeightInside(Backend* bn, const Convolution2DCommon* common, // reorder weight, [oc, ic, k^2] => [oc/unit, ((ic/unit)*k^2)/(src_unit/unit), unit(oc), (src_unit/unit), unit(ic)] int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY(); std::vector shape; - if (SRC_UNIT > UNIT) { + int pack = gcore->pack; + if (gcore->bytes == 2 && gcore->pack == 8) { + pack = 4; + } + if (SRC_UNIT > pack) { MNN_ASSERT(SRC_UNIT % UNIT == 0); - shape = {UP_DIV(oc, UNIT), UP_DIV(UP_DIV(ic, UNIT) * kernelCount, SRC_UNIT / UNIT), UNIT, SRC_UNIT}; + shape = {UP_DIV(oc, UNIT), UP_DIV(UP_DIV(ic, pack) * kernelCount, SRC_UNIT / pack), UNIT, SRC_UNIT}; } else { shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; } @@ -108,7 +112,7 @@ static bool _reorderWeightInside(Backend* bn, const Convolution2DCommon* common, MNN_ERROR("Memory not enough"); return false; } - ConvInt8TiledExecutor::reorderWeight(weight.get(), weightOrigin->host(), SRC_UNIT, UNIT, ic, oc, kernelCount); + ConvInt8TiledExecutor::reorderWeight(weight.get(), weightOrigin->host(), SRC_UNIT, UNIT, ic, oc, kernelCount, pack); return true; } diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp index ec2d78393..d4524837c 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp @@ -23,7 +23,7 @@ class ConvInt8TiledExecutor : public CPUConvolution { virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; virtual void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) = 0; - static void reorderWeight(Tensor* weight, const uint8_t* weightSrc, int SRC_UNIT, int UNIT, int ic, int oc, int kernelCount); + static void reorderWeight(Tensor* weight, const uint8_t* weightSrc, int SRC_UNIT, int UNIT, int ic, int oc, int kernelCount, int pack); protected: ConvolutionCommon::Im2ColParameter mIm2ColParamter; diff --git a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp index e2e0f16bc..5b3adc2e1 100644 --- a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp @@ -122,12 +122,7 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara int UNIT, SRC_UNIT, DynamicDestUnit; auto core = int8Core; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DynamicDestUnit); - if (floatCore->bytes == 2 && DynamicDestUnit == 20) { - UNIT = 8; - SRC_UNIT= 8; - DynamicDestUnit = 10; - } - if (SRC_UNIT > UNIT) { + if (SRC_UNIT > pack) { const auto srcCountUnit = UP_DIV(input->channel(), pack); dstIm2ColParamter.kernelCountUnit = UP_DIV(srcCountUnit * kernelCount, SRC_UNIT / pack); dstIm2ColParamter.ic = dstIm2ColParamter.icDiv4 * pack; diff --git a/source/backend/cpu/compute/GemmInt8Executor.cpp b/source/backend/cpu/compute/GemmInt8Executor.cpp index 00e501e5d..bc5abc93b 100644 --- a/source/backend/cpu/compute/GemmInt8Executor.cpp +++ b/source/backend/cpu/compute/GemmInt8Executor.cpp @@ -82,12 +82,12 @@ ErrorCode GemmInt8Executor::onResize(const std::vector &inputs, const mIm2ColParamter.padX = 0; mIm2ColParamter.padY = 0; mIm2ColParamter.kernelCountUnit = UP_DIV(input->channel(), SRC_UNIT); - if (SRC_UNIT > UNIT___) { + if (SRC_UNIT > UNIT___ && UNIT___ == pack) { const auto srcCountUnit = UP_DIV(input->channel(), pack); mIm2ColParamter.ic = mIm2ColParamter.icDiv4 * pack; } else { const auto srcCountUnit = UP_DIV(input->channel(), SRC_UNIT); - mIm2ColParamter.ic = srcCountUnit * SRC_UNIT; + mIm2ColParamter.ic = mIm2ColParamter.icDiv4 * pack; } mTileCnt = UP_DIV(input->height() * input->width() * input->batch(), DST_XUNIT); diff --git a/source/backend/cpu/compute/IdstConvolutionInt8.cpp b/source/backend/cpu/compute/IdstConvolutionInt8.cpp index bec8d7109..025ed8763 100644 --- a/source/backend/cpu/compute/IdstConvolutionInt8.cpp +++ b/source/backend/cpu/compute/IdstConvolutionInt8.cpp @@ -65,7 +65,7 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back auto kernelCount = kx * ky; auto srcCount = mSrcCount; std::vector shape; - if (SRC_UNIT > UNIT) { + if (SRC_UNIT > UNIT && UNIT == PackUnit) { MNN_ASSERT(SRC_UNIT % UNIT == 0); shape = {UP_DIV(outputCount, UNIT), UP_DIV(UP_DIV(srcCount, UNIT) * kernelCount, SRC_UNIT / UNIT), UNIT, SRC_UNIT}; } else { @@ -81,7 +81,7 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back MNN_ERROR("Memory not enough\n"); return; } - ConvInt8TiledExecutor::reorderWeight(mWeight.get(), (uint8_t*)common->weight.get(), SRC_UNIT, UNIT, srcCount, outputCount, kernelCount); + ConvInt8TiledExecutor::reorderWeight(mWeight.get(), (uint8_t*)common->weight.get(), SRC_UNIT, UNIT, srcCount, outputCount, kernelCount, PackUnit); ::memset(mFakeBias->host(), 0, mFakeBias->size()); ::memset(mFakeWeightBias->host(), 0, mFakeWeightBias->size()); #ifdef MNN_USE_SSE From 37e6192e8265ac60229dfa1bfdf40701f4821c40 Mon Sep 17 00:00:00 2001 From: "jingbang.yjb" Date: Mon, 22 Jul 2024 13:39:42 +0800 Subject: [PATCH 2/3] Fix Nearest and Bilinear sampler compute error and add unit tests. Signed-off-by: jingbang.yjb --- .../cpu/arm/arm64/MNNSamplerC4BilinearOpt.S | 46 ++++---- .../cpu/arm/arm64/MNNSamplerC4NearestOpt.S | 10 +- test/cv/ImageProcessTest.cpp | 108 ++++++++++++++++++ 3 files changed, 136 insertions(+), 28 deletions(-) diff --git a/source/backend/cpu/arm/arm64/MNNSamplerC4BilinearOpt.S b/source/backend/cpu/arm/arm64/MNNSamplerC4BilinearOpt.S index e755066fc..cb0757814 100644 --- a/source/backend/cpu/arm/arm64/MNNSamplerC4BilinearOpt.S +++ b/source/backend/cpu/arm/arm64/MNNSamplerC4BilinearOpt.S @@ -90,31 +90,31 @@ L1: cmp x3, #0 beq End mov v16.s[0], w4 -mov v16.s[1], w5 +mov v16.s[1], w5 // v16:[xMax, yMax] mov w12, #4 -mov v7.s[0], w12 -mov v7.s[1], w6 +mov v7.s[0], w12 // bpp=4 +mov v7.s[1], w6 // yStride dup v20.2d, x0 L1Loop: -fcvtms v2.2s, v0.2s +fcvtzs v2.2s, v0.2s // [x0, y0] frintm v4.2s, v0.2s -smax v2.2s, v2.2s, v19.2s -fcvtps v3.2s, v0.2s -fabd v4.2s, v0.2s, v4.2s +smax v2.2s, v2.2s, v19.2s // max(0, y) +fcvtps v3.2s, v0.2s // [x1, y1] +fabd v4.2s, v0.2s, v4.2s // (xF, yF) smax v3.2s, v3.2s, v19.2s smin v2.2s, v2.2s, v16.2s smin v3.2s, v3.2s, v16.2s -mul v2.2s, v2.2s, v7.2s -mul v3.2s, v3.2s, v7.2s -mov v2.s[2], v3.s[0] -mov v3.s[2], v2.s[0] +mul v2.2s, v2.2s, v7.2s // [bpp * x0, y0 * yStride] +mul v3.2s, v3.2s, v7.2s // [bpp * x1, y1 * yStride] +mov v2.s[2], v3.s[0] // v2: [bpp*x0, y0*yStride, bpp*x1, y0*yStride] +mov v3.s[2], v2.s[0] // v3: [bpp*x1, y1*yStride, bpp*x0, y1*yStride] mov v2.s[3], v2.s[1] mov v3.s[3], v3.s[1] -uaddlp v2.2d, v2.4s -uaddlp v3.2d, v3.4s +uaddlp v2.2d, v2.4s // [c00, c01] +uaddlp v3.2d, v3.4s // [c11, c10] add v2.2d, v20.2d, v2.2d add v3.2d, v20.2d, v3.2d @@ -131,25 +131,25 @@ uxtl v6.8h, v6.8b //Now v2, v3 is of no use //v2: LT, v3: RT, v5: LB, v6:BT -uxtl v2.4s, v5.4h -uxtl2 v3.4s, v5.8h +uxtl v2.4s, v5.4h // c00 +uxtl2 v3.4s, v5.8h // c01 ucvtf v2.4s, v2.4s -uxtl v5.4s, v6.4h +uxtl v5.4s, v6.4h // c11 ucvtf v3.4s, v3.4s -uxtl2 v6.4s, v6.8h +uxtl2 v6.4s, v6.8h // c10 ucvtf v5.4s, v5.4s ucvtf v6.4s, v6.4s fsub v3.4s, v3.4s, v2.4s -fsub v6.4s, v6.4s, v5.4s -fmla v2.4s, v3.4s, v4.s[0] -fmla v5.4s, v6.4s, v4.s[0] +fsub v5.4s, v5.4s, v6.4s +fmla v2.4s, v3.4s, v4.s[0] // (c01-c00)*xF+c00 +fmla v6.4s, v5.4s, v4.s[0] // (c11-c10)*xF+c10 -fsub v5.4s, v5.4s, v2.4s -fmla v2.4s, v5.4s, v4.s[1] +fsub v6.4s, v6.4s, v2.4s +fmla v2.4s, v6.4s, v4.s[1] -fcvtns v2.4s, v2.4s +fcvtzs v2.4s, v2.4s uqxtn v2.4h, v2.4s uqxtn v2.8b, v2.8h diff --git a/source/backend/cpu/arm/arm64/MNNSamplerC4NearestOpt.S b/source/backend/cpu/arm/arm64/MNNSamplerC4NearestOpt.S index fc0bc8d77..da7d12d92 100644 --- a/source/backend/cpu/arm/arm64/MNNSamplerC4NearestOpt.S +++ b/source/backend/cpu/arm/arm64/MNNSamplerC4NearestOpt.S @@ -44,13 +44,13 @@ mov v5.s[2], v3.s[1] mov v4.s[3], v2.s[0] mov v5.s[3], v2.s[1] -dup v23.4s, w6 +dup v23.4s, w6 // yStride movi v24.4s, #4 dup v22.2d, x0 L4Loop: -fcvtns v6.4s, v4.4s -fcvtns v7.4s, v5.4s +fcvtas v6.4s, v4.4s // x +fcvtas v7.4s, v5.4s // y smin v6.4s, v6.4s, v16.4s smin v7.4s, v7.4s, v17.4s @@ -58,7 +58,7 @@ smax v6.4s, v6.4s, v19.4s smax v7.4s, v7.4s, v19.4s mul v7.4s, v7.4s, v23.4s -mla v7.4s, v6.4s, v24.4s +mla v7.4s, v6.4s, v24.4s // offset = y * yStride + bpp * x uxtl v6.2d, v7.2s uxtl2 v7.2d, v7.4s add v6.2d, v6.2d, v22.2d @@ -95,7 +95,7 @@ mov w12, #4 L1Loop: -fcvtns v2.2s, v0.2s +fcvtas v2.2s, v0.2s smax v2.2s, v2.2s, v19.2s smin v2.2s, v2.2s, v6.2s mov w4, v2.s[0] diff --git a/test/cv/ImageProcessTest.cpp b/test/cv/ImageProcessTest.cpp index 936b39fdb..7f689fac5 100644 --- a/test/cv/ImageProcessTest.cpp +++ b/test/cv/ImageProcessTest.cpp @@ -696,3 +696,111 @@ class ImageProcessYUVBlitterTest : public ImageProcessYUVTestCommmon { }; // {YUV_NV21, YUV_NV12, YUV_I420} -> {RGBA, RGB, BGRA, BGR, GRAY} unit test MNNTestSuiteRegister(ImageProcessYUVBlitterTest, "cv/image_process/yuv_blitter"); + +static bool funcToColorResize(int iw, int ih, int ic, int ow, int oh, int oc, Filter filtertype, ImageFormat srcFormat, ImageFormat dstFormat) { + auto srcImg = genSourceData(ih, iw, ic); + auto dstType = halide_type_of(); + + float fx = static_cast(iw) / ow; + float fy = static_cast(ih) / oh; + ImageProcess::Config config0, config1; + + // resize first + config0.sourceFormat = srcFormat; + config0.destFormat = srcFormat; + config0.filterType = filtertype; + std::unique_ptr process0(ImageProcess::create(config0)); + auto resizeTensor = Tensor::create({1, oh, ow, ic}, dstType); + Matrix tr; + tr.postScale(fx, fy); + tr.postTranslate(0.5 * (fx - 1), 0.5 * (fy - 1)); + process0->setMatrix(tr); + process0->convert(srcImg.data(), iw, ih, 0, resizeTensor->host(), ow, oh, ic, 0, dstType); + + // then convert color + config1.sourceFormat = srcFormat; + config1.destFormat = dstFormat; + config1.filterType = filtertype; + std::unique_ptr process1(ImageProcess::create(config1)); + auto colorTensor = Tensor::create({1, oh, ow, oc}, dstType); + Matrix tr1; + tr1.postScale(1.f, 1.f); + tr1.postTranslate(0, 0); + process1->setMatrix(tr1); + process1->convert(resizeTensor->host(), ow, oh, 0, colorTensor->host(), ow, oh, oc, 0, dstType); + + // convert color first + ImageProcess::Config config2, config3; + config2.sourceFormat = srcFormat; + config2.destFormat = dstFormat; + config2.filterType = filtertype; + + std::unique_ptr process2(ImageProcess::create(config2)); + auto colorTensor2 = Tensor::create({1, ih, iw, oc}, dstType); + Matrix tr2; + tr2.postScale(1.f, 1.f); + tr2.postTranslate(0.f, 0.f); + process2->setMatrix(tr2); + process2->convert(srcImg.data(), iw, ih, 0, colorTensor2->host(), iw, ih, oc, 0, dstType); + + // Second: resize + config3.sourceFormat = dstFormat; + config3.destFormat = dstFormat; + config3.filterType = filtertype; + + std::unique_ptr process3(ImageProcess::create(config3)); + auto resizeTensor3 = Tensor::create({1, oh, ow, oc}, dstType); + Matrix tr3; + tr3.postScale(fx, fy); + tr3.postTranslate(0.5 * (fx - 1), 0.5 * (fy - 1)); + process3->setMatrix(tr3); + process3->convert(colorTensor2->host(), iw, ih, 0, resizeTensor3->host(), ow, oh, oc, 0, dstType); + + // compare these two results + auto res1Ptr = colorTensor->host(); + auto res2Ptr = resizeTensor3->host(); + auto size_ = resizeTensor3->size(); + for (int i = 0; i < (int)size_; ++i) { + if (res1Ptr[i] != res2Ptr[i]) { + return false; + } + } + return true; +} + +class ImageProcessColorResizeTest: public MNNTestCase { + // Test: first color then resize and first resize then color, these two results are same. + virtual ~ImageProcessColorResizeTest() = default; + virtual bool run(int precison) { + std::vector filters(NEAREST, BILINEAR); + for (int iw = 2; iw < 200; iw += 17) { + for (int ih = 7; ih < 200; ih += 19) { + for (int ow = 2; ow < 200; ow += 17) { + for (int oh = 8; oh < 240; oh += 30) { + for (int f = 0; f < filters.size(); ++f) { + int ic = 4; + int oc = 3; + bool res = funcToColorResize(iw, ih, ic, ow, oh, oc, filters[f], RGBA, RGB); + if (!res) { + MNN_PRINT("iw=%d, ih=%d, ic=%d, ow=%d, oh=%d, oc=%d, filtertype=%d, RGBA->RGB\n", iw, ih, ic, ow, oh, oc, filters[f]); + return false; + } + ic = 3; + oc = 4; + res &= funcToColorResize(iw, ih, ic, ow, oh, oc, filters[f], RGB, RGBA); + if (!res) { + MNN_PRINT("iw=%d, ih=%d, ic=%d, ow=%d, oh=%d, oc=%d, filtertype=%d, RGB->RGBA\n", iw, ih, ic, ow, oh, oc, filters[f]); + return false; + } + + } + + } + } + } + } + return true; + } +}; +MNNTestSuiteRegister(ImageProcessColorResizeTest, "cv/image_process/color_resize_test"); + From 602b2e22a470ab7399fb2576ba989d001049425a Mon Sep 17 00:00:00 2001 From: xiaying Date: Thu, 25 Jul 2024 12:06:03 +0800 Subject: [PATCH 3/3] Doc:Feature: Add correct check in faq --- docs/faq.md | 13 ++++++++++++- docs/tools/convert.md | 8 +++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/faq.md b/docs/faq.md index d21236df3..c9b6344c0 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -117,7 +117,15 @@ opConverter ==> MNN Converter NOT_SUPPORTED_OP: [ ANY_OP_NAME ] 临时解决方案:升级 numpy 版本到 1.20.0 或以上 ## 运行问题 -### 运行结果出错 / Tensor 的 elementSize 不为各维度乘积 +### 运行结果出错 +- 先使用 testMNNFromOnnx.py 等测试工具进行测试,具体参见模型转换工具的正确性校验部分 +- 测试工具验证正确,但运行代码结果出错,可能是如下原因: + 1. 使用 Session API 运行不满足运行条件的模型,此时应换用 Module API + 2. 输入的内存布局不对 + 3. 输入数据格式不对,int64 需要换成 int32_t ,double 需要换成 float + + +### 布局转换问题(Tensor 的 elementSize 不为各维度乘积) MNN 内部对 CV 相关算子采用 NC4HW4 布局,计算 elementSize 时,channel 会上对齐到 4 返回,此内存布局允许实现的硬件自行确定内存排列方式,具体方式不对用户可见,但用户可以通过如下代码,输入或获取自己指定的NCHW / NHWC 内存布局的 Tensor / VARP。 #### Interpreter-Session API @@ -237,6 +245,9 @@ OpenCL / Vulkan 采用静态变量自注册的方式往 MNN 主库注册后端. - 目前支持OpenCL和CUDA后端进行设置 - 具体可以参考:tools/cpp/testModel.cpp +### Register 相关内存泄露说明 +用 valgrind 工具检查时会报 MNN Register 相关的内存泄露,这个属于一次性的初始化内存,后续也不会增加,可视为误报 + ## 性能相关 ### 使用 GPU 时,调用 copyToHostTensor / copyFromHostTensor 非常慢 diff --git a/docs/tools/convert.md b/docs/tools/convert.md index f8ab79101..fdc707bc1 100644 --- a/docs/tools/convert.md +++ b/docs/tools/convert.md @@ -145,7 +145,13 @@ model_script.save('model_script.pt') - testMNNFromOnnx.py :适用 onnx - testMNNFromTorch.py :适用 pt (torchscript) -注意:对于由Torchscript转换的模型,需要自行修改`testMNNFromTorch.py`中的的输入信息来测试 +注意: + +- 如果模型是动态输入形状,MNN 在脚本中默认不固定部分为1,有可能在 Tensorflow / OnnxRuntime / Torch 验证阶段报错。此时需要修改脚本中对应的输入部分,比如 testMNNFromOnnx.py 中的 run_onnx(self) 函数,把输入替换为有效的输入形状和内容。 +- 对于由Torchscript转换的模型,一般都需要自行修改`testMNNFromTorch.py`中的的输入信息来测试。 +- 如果模型输出层是 Identity 产生的,会因为 MNN 图优化的缘故丢失,此时需要校验上一层的输出,即在脚本后接输出名来测试,如: python3 ../tools/scripts/testMNNFromTf.py XXX.pb $NAME$ + + ### 前置 - 测试 pb / tflite :安装`tensorflow`(`pip install tensorflow`) - 测试 onnx : 安装`onnxruntime`(`pip install onnxruntime`)