Skip to content

Commit

Permalink
Merge branch 'xnnpack' into v1.17.3-bricks
Browse files Browse the repository at this point in the history
  • Loading branch information
hans00 committed May 30, 2024
2 parents 304fb78 + 7cc905c commit b349298
Show file tree
Hide file tree
Showing 5 changed files with 422 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
// each operator provides a helper to check if supported
#include "core/providers/xnnpack/math/gemm.h"
#include "core/providers/xnnpack/math/matmul.h"
#include "core/providers/xnnpack/math/matmul_int.h"
#include "core/providers/xnnpack/math/softmax.h"
#include "core/providers/xnnpack/nn/average_pool.h"
#include "core/providers/xnnpack/nn/conv.h"
Expand Down
58 changes: 39 additions & 19 deletions onnxruntime/core/providers/xnnpack/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,30 @@ bool MatMul::IsOnnxNodeSupported(const NodeUnit& node_unit, const GraphViewer& g
// Support only float
const auto* A_type = A_arg.TypeAsProto();

if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
break;
}

const auto* A_shape = A_arg.Shape();
const auto* B_shape = B_arg.Shape();

if (A_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
if (A_shape == nullptr || B_shape == nullptr) {
break;
}

size_t A_rank = A_shape->dim_size();
size_t B_rank = B_shape->dim_size();

// Support A [M, K] or [batch, M, K] x B [K, N] or [N]
if (B_rank > 2 || (A_rank != B_rank && A_rank != B_rank + 1)) {
break;
}

if (A_shape == nullptr || A_shape->dim_size() > 2 ||
(A_shape->dim_size() == 2 && A_shape->dim(1).dim_value() == 0) ||
A_shape->dim(0).dim_value() == 0) {
if (B_shape->dim(0).dim_value() == 0) {
break;
}

if (B_shape == nullptr || B_shape->dim_size() > 2 ||
(B_shape->dim_size() == 2 && B_shape->dim(1).dim_value() == 0) ||
B_shape->dim(0).dim_value() == 0) {
if (B_rank == 2 && B_shape->dim(1).dim_value() == 0) {
break;
}

Expand Down Expand Up @@ -128,20 +136,32 @@ Status MatMul::Compute(OpKernelContext* ctx) const {

auto* y_data = y->MutableData<float>();

xnn_status status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_fully_connected_nc_f32 returned ", status);
}

status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_setup_fully_connected_nc_f32 returned ", status);
xnn_status status = xnn_status::xnn_status_uninitialized;
auto a_shape = a->Shape();

if (a_shape.NumDimensions() != b_shape_.NumDimensions()) {
// A is [batch, ..., K] and B is [K, N] output is [batch, ..., N]
size_t batch_size = a_shape[0];
size_t M = a_shape[1];
for (size_t i = 0; i < batch_size; i++) {
size_t offset = i * M;
status = xnn_reshape_fully_connected_nc_f32(op0_.get(), M, threadpool);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_reshape_fully_connected_nc_f32 returned ", status);
status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>() + offset, y_data + offset);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_setup_fully_connected_nc_f32 returned ", status);
status = xnn_run_operator(op0_.get(), nullptr);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_run_operator returned ", status);
}
} else {
// A is [M, K] and B is [K, N]
status = xnn_reshape_fully_connected_nc_f32(op0_.get(), a->Shape()[0], threadpool);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_reshape_fully_connected_nc_f32 returned ", status);
status = xnn_setup_fully_connected_nc_f32(op0_.get(), a->Data<float>(), y_data);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_setup_fully_connected_nc_f32 returned ", status);
status = xnn_run_operator(op0_.get(), nullptr);
ORT_RETURN_IF_NOT(xnn_status_success == status, "xnn_run_operator returned ", status);
}

status = xnn_run_operator(op0_.get(), nullptr);
if (status != xnn_status_success) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_run_operator returned ", status);
}
return Status::OK();
}

Expand Down
Loading

0 comments on commit b349298

Please sign in to comment.