From 06dd1219274bb9bb8a9c5cfb2a92c19c518339e4 Mon Sep 17 00:00:00 2001
From: Hans
Date: Mon, 15 Apr 2024 15:06:33 +0800
Subject: [PATCH 1/2] MatMul: support a dynamic-shape A matrix of up to rank 3

---
 .../core/providers/xnnpack/math/matmul.cc     | 58 +++++++++++++------
 1 file changed, 39 insertions(+), 19 deletions(-)

diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc
index e90aa11c9d087..1f42217769205 100644
--- a/onnxruntime/core/providers/xnnpack/math/matmul.cc
+++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc
@@ -31,22 +31,30 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g
 
   // Support only float
   const auto* A_type = A_arg.TypeAsProto();
+  if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+    break;
+  }
+
   const auto* A_shape = A_arg.Shape();
   const auto* B_shape = B_arg.Shape();
 
-  if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
+  if (A_shape == nullptr || B_shape == nullptr) {
+    break;
+  }
+
+  size_t A_rank = A_shape->dim_size();
+  size_t B_rank = B_shape->dim_size();
+
+  // Support A [M, K] or [batch, M, K] x B [K, N] or [N]
+  if (B_rank > 2 || (A_rank != B_rank && A_rank != B_rank + 1)) {
     break;
   }
 
-  if (A_shape == nullptr || A_shape->dim_size() > 2 ||
-      (A_shape->dim_size() == 2 && A_shape->dim(1).dim_value() == 0) ||
-      A_shape->dim(0).dim_value() == 0) {
+  if (B_shape->dim(0).dim_value() == 0) {
     break;
   }
 
-  if (B_shape == nullptr || B_shape->dim_size() > 2 ||
-      (B_shape->dim_size() == 2 && B_shape->dim(1).dim_value() == 0) ||
-      B_shape->dim(0).dim_value() == 0) {
+  if (B_rank == 2 && B_shape->dim(1).dim_value() == 0) {
     break;
   }
 
@@ -128,20 +136,32 @@ Status MatMul::Compute(OpKernelContext* ctx) const {
 
   auto* y_data = y->MutableData<float>();
 
-  xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
-  if (status != xnn_status_success) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
-  }
-
-  status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
-  if (status != xnn_status_success) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
+  xnn_status status = xnn_status::xnn_status_uninitialized;
+  auto a_shape = a->Shape();
+
+  if (a_shape.NumDimensions() == 3) {
+    // A is [batch, M, K] and B is [K, N]; the output is [batch, M, N].
+    // Each batch slice is an independent [M, K] x [K, N] matmul, so advance
+    // the input pointer by M * K and the output pointer by M * N per slice.
+    size_t batch_size = a_shape[0];
+    size_t M = a_shape[1];
+    size_t K = a_shape[2];
+    size_t N = b_shape_[1];
+    for (size_t i = 0; i < batch_size; i++) {
+      status = xnn_reshape_fully_connected_nc_f32(op0_.get(), M, threadpool);
+      ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_reshape_fully_connected_nc_f32 returned ", status);
+      status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>() + i * M * K, y_data + i * M * N);
+      ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_setup_fully_connected_nc_f32 returned ", status);
+      status = xnn_run_operator(op0_.get(), nullptr);
+      ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_run_operator returned ", status);
+    }
+  } else {
+    // A is [M, K] and B is [K, N] or [N]; a single call covers all M rows.
+    status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a_shape[0], threadpool);
+    ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_reshape_fully_connected_nc_f32 returned ", status);
+    status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
+    ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_setup_fully_connected_nc_f32 returned ", status);
+    status = xnn_run_operator(op0_.get(), nullptr);
+    ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_run_operator returned ", status);
   }
 
-  status = xnn_run_operator(op0_.get(), nullptr);
-  if (status != xnn_status_success) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
-  }
   return Status::OK();
 }
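[Reviewer note] The sketch below is a minimal, self-contained illustration (not part of the patch) of the mapping patch 1 relies on: a [batch, M, K] x [K, N] MatMul decomposed into per-batch calls to XNNPACK's 2-D fully-connected operator. It assumes an XNNPACK version with the reshape/setup split used above; all shapes, data, and the null bias are illustrative.

// mini_batched_matmul.cc -- illustrative only; not part of this patch.
#include <xnnpack.h>
#include <cmath>    // INFINITY
#include <cstdio>
#include <vector>

int main() {
  if (xnn_initialize(/*allocator=*/nullptr) != xnn_status_success) return 1;

  const size_t batch = 2, M = 3, K = 4, N = 5;
  std::vector<float> A(batch * M * K, 1.0f);  // [batch, M, K]
  std::vector<float> B(K * N, 0.5f);          // [K, N], row-major
  std::vector<float> Y(batch * M * N);        // [batch, M, N]

  // XNN_FLAG_TRANSPOSE_WEIGHTS lets B stay in [K, N] (input-channel major) layout.
  xnn_operator_t fc = nullptr;
  if (xnn_create_fully_connected_nc_f32(
          /*input_channels=*/K, /*output_channels=*/N,
          /*input_stride=*/K, /*output_stride=*/N,
          B.data(), /*bias=*/nullptr,
          /*output_min=*/-INFINITY, /*output_max=*/INFINITY,
          XNN_FLAG_TRANSPOSE_WEIGHTS,
          /*code_cache=*/nullptr, /*weights_cache=*/nullptr,
          &fc) != xnn_status_success) return 1;

  // Each batch slice is an independent [M, K] x [K, N] matmul: advance the
  // input by M * K floats and the output by M * N floats per iteration.
  for (size_t i = 0; i < batch; i++) {
    if (xnn_reshape_fully_connected_nc_f32(fc, /*batch_size=*/M, /*threadpool=*/nullptr) != xnn_status_success ||
        xnn_setup_fully_connected_nc_f32(fc, A.data() + i * M * K, Y.data() + i * M * N) != xnn_status_success ||
        xnn_run_operator(fc, /*threadpool=*/nullptr) != xnn_status_success) return 1;
  }

  std::printf("Y[0] = %f (expect K * 1.0 * 0.5 = %f)\n", Y[0], K * 0.5);
  xnn_delete_operator(fc);
  return 0;
}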
"xnn_setup_fully_connected_nc_f32 returned ", status); + status = xnn_run_operator(op0_.get(), nullptr); + ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_run_operator returned ", status); } - status = xnn_run_operator(op0_.get(), nullptr); - if (status != xnn_status_success) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status); - } return Status::OK(); } From 7cc905c2d6e1c37f9ac188fd48785fcdf0b75ba2 Mon Sep 17 00:00:00 2001 From: Hans Date: Mon, 15 Apr 2024 15:06:50 +0800 Subject: [PATCH 2/2] Support MatMulInteger --- .../xnnpack/detail/node_support_checker.cc | 1 + .../core/providers/xnnpack/math/matmul_int.cc | 322 ++++++++++++++++++ .../core/providers/xnnpack/math/matmul_int.h | 57 ++++ .../xnnpack/xnnpack_execution_provider.cc | 3 + 4 files changed, 383 insertions(+) create mode 100644 onnxruntime/core/providers/xnnpack/math/matmul_int.cc create mode 100644 onnxruntime/core/providers/xnnpack/math/matmul_int.h diff --git a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc index e2d71cda68ec4..b293b14347945 100644 --- a/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc +++ b/onnxruntime/core/providers/xnnpack/detail/node_support_checker.cc @@ -17,6 +17,7 @@ // each operator provides a helper to check if supported #include "core/providers/xnnpack/math/gemm.h" #include "core/providers/xnnpack/math/matmul.h" +#include "core/providers/xnnpack/math/matmul_int.h" #include "core/providers/xnnpack/math/softmax.h" #include "core/providers/xnnpack/nn/average_pool.h" #include "core/providers/xnnpack/nn/conv.h" diff --git a/onnxruntime/core/providers/xnnpack/math/matmul_int.cc b/onnxruntime/core/providers/xnnpack/math/matmul_int.cc new file mode 100644 index 0000000000000..4f05b9c7a743a --- /dev/null +++ b/onnxruntime/core/providers/xnnpack/math/matmul_int.cc @@ -0,0 +1,322 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "matmul_int.h" +#include "core/providers/cpu/math/matmul_helper.h" + +// Todo - +// 1. Integrate activation layers - Cliping & Relu +// 2. Enable Quant ops +// 3. 
+namespace onnxruntime {
+namespace xnnpack {
+
+bool MatMulIntegerCommon::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph) {
+  bool supported = false;
+  const onnxruntime::Node& node = node_unit.GetNode();
+
+  // use do {} while(false) so it's easier to set a breakpoint on the return
+  do {
+    auto input_defs = node.InputDefs();
+
+    if (input_defs.size() < 2) {
+      break;
+    }
+
+    const auto& A_arg = *input_defs[0];
+    const auto& B_arg = *input_defs[1];
+
+    const auto* A_shape = A_arg.Shape();
+    const auto* B_shape = B_arg.Shape();
+
+    if (A_shape == nullptr || B_shape == nullptr) {
+      break;
+    }
+
+    size_t A_rank = A_shape->dim_size();
+    size_t B_rank = B_shape->dim_size();
+
+    // Support A [M, K] or [batch, M, K] x B [K, N] or [N]
+    if (B_rank > 2 || (A_rank != B_rank && A_rank != B_rank + 1)) {
+      break;
+    }
+
+    if (B_shape->dim(0).dim_value() == 0) {
+      break;
+    }
+
+    if (B_rank == 2 && B_shape->dim(1).dim_value() == 0) {
+      break;
+    }
+
+    // B matrix must be constant
+    if (!graph.IsConstantInitializer(B_arg.Name(), true)) {
+      break;
+    }
+
+    // b_zero_point must be constant
+    if (input_defs.size() > 3) {
+      if (!graph.IsConstantInitializer(input_defs[3]->Name(), true)) {
+        break;
+      }
+    }
+
+    supported = true;
+
+  } while (false);
+
+  return supported;
+}
+
+template <>
+Status MatMulInteger<int8_t>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                                      /*out*/ bool& is_packed,
+                                      /*out*/ PrePackedWeights* /*Not used*/) {
+  is_packed = false;
+
+  if (input_idx == 0) {
+    return Status::OK();
+  }
+
+  if (input_idx == 1) {
+    b_shape_ = tensor.Shape();
+    B_ = &tensor;
+  }
+
+  if (input_idx == 2 && has_a_zero_point_) {
+    a_zero_point_ = tensor.Data<int8_t>()[0];
+  }
+
+  if (input_idx == 3 && has_b_zero_point_) {
+    b_zero_point_ = tensor.Data<int8_t>()[0];
+  }
+
+  // Create the operator once the last expected constant input has arrived.
+  if ((!has_a_zero_point_ || input_idx >= 2) && (!has_b_zero_point_ || input_idx >= 3)) {
+    myAlloc = alloc;
+
+    uint32_t flags = XNN_FLAG_TRANSPOSE_WEIGHTS;
+    xnn_status status = xnn_status::xnn_status_uninitialized;
+    struct xnn_operator* p = nullptr;
+    auto shape_broadcast = b_shape_.AsShapeVector();
+    if (b_shape_.NumDimensions() == 1) {
+      shape_broadcast.push_back(1);
+    }
+    // Note: the qs8 operator assumes a symmetric kernel, so b_zero_point_ is
+    // not passed through here.
+    status = xnn_create_fully_connected_nc_qs8(
+        shape_broadcast[0],  // size_t input_channels,
+        shape_broadcast[1],  // size_t output_channels,
+        shape_broadcast[0],  // size_t input_stride,
+        shape_broadcast[1],  // size_t output_stride,
+        a_zero_point_,       // int8_t input_zero_point,
+        1.0f,                // float input_scale,
+        1.0f,                // float kernel_scale,
+        B_->Data<int8_t>(),  // const int8_t* kernel,
+        nullptr,             // const int32_t* bias,
+        0,                   // int8_t output_zero_point,
+        1.0f,                // float output_scale,
+        INT8_MIN,            // int8_t output_min,
+        INT8_MAX,            // int8_t output_max,
+        flags,
+#ifdef XNN_CACHE_ENABLE
+        GetCodeCache(),
+        GetWeightsCache(),
+#else
+        nullptr,
+        nullptr,
+#endif
+        &p);
+
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_create_fully_connected_nc_qs8 returned ", status);
+
+    op0_.reset(p);
+  }
+
+  is_packed = true;
+
+  return Status::OK();
+}
+
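+// For reference, a sketch of what the XNNPACK quantized fully-connected
+// operators compute per output element (the qs8 variant is symmetric, i.e.
+// its kernel zero point is fixed at 0; qu8 takes an explicit kernel_zero_point):
+//   accum = sum_k (int32_t)(a[k] - input_zero_point) * (int32_t)(w[k] - kernel_zero_point)
+//   y     = clamp(round(accum * input_scale * kernel_scale / output_scale) + output_zero_point, out_min, out_max)
+// With the unit scales and zero output zero point used here, this reduces to
+// clamping the raw int32 accumulator into the 8-bit output range.
+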
+template <>
+Status MatMulInteger<uint8_t>::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                                       /*out*/ bool& is_packed,
+                                       /*out*/ PrePackedWeights* /*Not used*/) {
+  is_packed = false;
+
+  if (input_idx == 0) {
+    return Status::OK();
+  }
+
+  if (input_idx == 1) {
+    b_shape_ = tensor.Shape();
+    B_ = &tensor;
+  }
+
+  if (input_idx == 2 && has_a_zero_point_) {
+    a_zero_point_ = tensor.Data<uint8_t>()[0];
+  }
+
+  if (input_idx == 3 && has_b_zero_point_) {
+    b_zero_point_ = tensor.Data<uint8_t>()[0];
+  }
+
+  // Create the operator once the last expected constant input has arrived.
+  if ((!has_a_zero_point_ || input_idx >= 2) && (!has_b_zero_point_ || input_idx >= 3)) {
+    myAlloc = alloc;
+
+    uint32_t flags = XNN_FLAG_TRANSPOSE_WEIGHTS;
+    xnn_status status = xnn_status::xnn_status_uninitialized;
+
+    struct xnn_operator* p = nullptr;
+    auto shape_broadcast = b_shape_.AsShapeVector();
+    if (b_shape_.NumDimensions() == 1) {
+      shape_broadcast.push_back(1);
+    }
+    status = xnn_create_fully_connected_nc_qu8(
+        shape_broadcast[0],   // size_t input_channels,
+        shape_broadcast[1],   // size_t output_channels,
+        shape_broadcast[0],   // size_t input_stride,
+        shape_broadcast[1],   // size_t output_stride,
+        a_zero_point_,        // uint8_t input_zero_point,
+        1.0f,                 // float input_scale,
+        b_zero_point_,        // uint8_t kernel_zero_point,
+        1.0f,                 // float kernel_scale,
+        B_->Data<uint8_t>(),  // const uint8_t* kernel,
+        nullptr,              // const int32_t* bias,
+        0,                    // uint8_t output_zero_point,
+        1.0f,                 // float output_scale,
+        0,                    // uint8_t output_min,
+        UINT8_MAX,            // uint8_t output_max,
+        flags,
+#ifdef XNN_CACHE_ENABLE
+        GetCodeCache(),
+        GetWeightsCache(),
+#else
+        nullptr,
+        nullptr,
+#endif
+        &p);
+
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_create_fully_connected_nc_qu8 returned ", status);
+
+    op0_.reset(p);
+  }
+
+  is_packed = true;
+
+  return Status::OK();
+}
+
+template <>
+Status MatMulInteger<int8_t>::Compute(OpKernelContext* ctx) const {
+  const Tensor* a = ctx->Input<Tensor>(0);
+  pthreadpool_t threadpool = GetThreadPool();
+  MatMulComputeHelper helper;
+  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
+  Tensor* y = ctx->Output(0, helper.OutputShape());
+
+  if (y->Shape().Size() == 0)
+    return Status::OK();
+
+  xnn_status status = xnn_status::xnn_status_uninitialized;
+
+  auto a_shape = a->Shape();
+
+  if (a_shape.NumDimensions() == 3) {
+    // A is [batch, M, K] and B is [K, N]; the output is [batch, M, N]
+    size_t batch_size = a_shape[0];
+    size_t M = a_shape[1];
+    size_t K = a_shape[2];
+    size_t N = b_shape_[1];
+    for (size_t i = 0; i < batch_size; i++) {
+      status = xnn_reshape_fully_connected_nc_qs8(op0_.get(), M, threadpool);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_reshape_fully_connected_nc_qs8 returned ", status);
+
+      status = xnn_setup_fully_connected_nc_qs8(op0_.get(), a->Data<int8_t>() + i * M * K,
+                                                y->MutableData<int8_t>() + i * M * N);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_setup_fully_connected_nc_qs8 returned ", status);
+
+      status = xnn_run_operator(op0_.get(), nullptr);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_run_operator returned ", status);
+    }
+  } else {
+    // A is [M, K] and B is [K, N] or [N]
+    status = xnn_reshape_fully_connected_nc_qs8(op0_.get(), a_shape[0], threadpool);
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_reshape_fully_connected_nc_qs8 returned ", status);
+
+    status = xnn_setup_fully_connected_nc_qs8(op0_.get(), a->Data<int8_t>(), y->MutableData<int8_t>());
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_setup_fully_connected_nc_qs8 returned ", status);
+
+    status = xnn_run_operator(op0_.get(), nullptr);
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_run_operator returned ", status);
+  }
+
+  return Status::OK();
+}
+
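+// Pointer arithmetic shared by the batched loops above and below: with
+// A = [batch, M, K] and B = [K, N], batch slice i of A starts i * M * K
+// elements into its buffer and the matching output slice starts i * M * N
+// elements into the output buffer.
+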
+template <>
+Status MatMulInteger<uint8_t>::Compute(OpKernelContext* ctx) const {
+  const Tensor* a = ctx->Input<Tensor>(0);
+  pthreadpool_t threadpool = GetThreadPool();
+  MatMulComputeHelper helper;
+  ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape_));
+  Tensor* y = ctx->Output(0, helper.OutputShape());
+
+  if (y->Shape().Size() == 0)
+    return Status::OK();
+
+  xnn_status status = xnn_status::xnn_status_uninitialized;
+
+  auto a_shape = a->Shape();
+
+  if (a_shape.NumDimensions() == 3) {
+    // A is [batch, M, K] and B is [K, N]; the output is [batch, M, N]
+    size_t batch_size = a_shape[0];
+    size_t M = a_shape[1];
+    size_t K = a_shape[2];
+    size_t N = b_shape_[1];
+    for (size_t i = 0; i < batch_size; i++) {
+      status = xnn_reshape_fully_connected_nc_qu8(op0_.get(), M, threadpool);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_reshape_fully_connected_nc_qu8 returned ", status);
+
+      status = xnn_setup_fully_connected_nc_qu8(op0_.get(), a->Data<uint8_t>() + i * M * K,
+                                                y->MutableData<uint8_t>() + i * M * N);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_setup_fully_connected_nc_qu8 returned ", status);
+
+      status = xnn_run_operator(op0_.get(), nullptr);
+      ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_run_operator returned ", status);
+    }
+  } else {
+    // A is [M, K] and B is [K, N] or [N]
+    status = xnn_reshape_fully_connected_nc_qu8(op0_.get(), a_shape[0], threadpool);
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_reshape_fully_connected_nc_qu8 returned ", status);
+
+    status = xnn_setup_fully_connected_nc_qu8(op0_.get(), a->Data<uint8_t>(), y->MutableData<uint8_t>());
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_setup_fully_connected_nc_qu8 returned ", status);
+
+    status = xnn_run_operator(op0_.get(), nullptr);
+    ORT_RETURN_IF_NOT(status == xnn_status_success, "xnn_run_operator returned ", status);
+  }
+
+  return Status::OK();
+}
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    MatMulInteger,
+    kOnnxDomain,
+    10,
+    uint8_t,
+    kXnnpackExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<uint8_t>())
+        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<uint8_t>(),
+                               DataTypeImpl::GetTensorType<int8_t>()})
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    MatMulInteger<uint8_t>);
+
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+    MatMulInteger,
+    kOnnxDomain,
+    10,
+    int8_t,
+    kXnnpackExecutionProvider,
+    KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int8_t>())
+        .TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
+    MatMulInteger<int8_t>);
+
+}  // namespace xnnpack
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/xnnpack/math/matmul_int.h b/onnxruntime/core/providers/xnnpack/math/matmul_int.h
new file mode 100644
index 0000000000000..26fa3db7fa157
--- /dev/null
+++ b/onnxruntime/core/providers/xnnpack/math/matmul_int.h
@@ -0,0 +1,57 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "core/providers/xnnpack/xnnpack_kernel.h"
+#include "core/framework/allocator.h"
+#include "core/providers/xnnpack/detail/utils.h"
+#include "core/common/common.h"
+#include "core/util/math.h"
+
+namespace onnxruntime {
+class GraphViewer;
+class Node;
+namespace xnnpack {
+
+struct MatMulIntegerCommon {
+  // Required for checking XNNPACK restrictions on the ORT side
+  static bool IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& graph);
+};
+
+template <typename T>
+class MatMulInteger : public XnnpackKernel {
+  using MatType = T;
+
+ public:
+  explicit MatMulInteger(const OpKernelInfo& info) : XnnpackKernel(info, /*enable_caches*/ true) {
+    if (info.GetInputCount() > 2) {
+      has_a_zero_point_ = true;
+    }
+    if (info.GetInputCount() > 3) {
+      has_b_zero_point_ = true;
+    }
+  }
+
+  Status Compute(OpKernelContext* /*context*/) const override;
+
+  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
+                 /*out*/ bool& is_packed,
+                 /*out*/ PrePackedWeights* prepacked_weights) override;
+
+ private:
+  AllocatorPtr myAlloc;
+
+  bool has_a_zero_point_ = false;
+  bool has_b_zero_point_ = false;
+
+  TensorShape b_shape_;
+  const Tensor* B_{nullptr};
+
+  MatType a_zero_point_ = 0;
+  MatType b_zero_point_ = 0;
+
+  XnnpackOperator op0_ = nullptr;
+};
+
+}  // namespace xnnpack
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
index 12e567e7080b3..63e1cee42f656 100644
--- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
+++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
@@ -131,6 +131,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
       KERNEL_CREATE_INFO_VERSIONED(9, 12, MatMul, kOnnxDomain),
       KERNEL_CREATE_INFO(13, MatMul, kOnnxDomain),
 
+      KERNEL_CREATE_INFO_TYPED(10, int8_t, MatMulInteger, kOnnxDomain),
+      KERNEL_CREATE_INFO_TYPED(10, uint8_t, MatMulInteger, kOnnxDomain),
+
       // quantization op
       KERNEL_CREATE_INFO(1, QLinearAveragePool, kMSInternalNHWCDomain),
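[Reviewer note] For testing these kernels end to end, a minimal sketch of enabling the XNNPACK EP from the C++ API, assuming an ORT build with XNNPACK enabled; the model path and thread count are illustrative.

// run_with_xnnpack.cc -- illustrative only; assumes an ORT build with XNNPACK.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "xnnpack-matmul-test");
  Ort::SessionOptions so;
  // XNNPACK runs on its own pthreadpool, sized via the provider option.
  so.AppendExecutionProvider("XNNPACK", {{"intra_op_num_threads", "4"}});
  // Nodes XNNPACK cannot take (e.g. unsupported shapes) fall back to the CPU EP.
  Ort::Session session(env, "model.onnx", so);  // path is illustrative
  // ... create Ort::Value inputs and call session.Run(...) as usual.
  return 0;
}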