Skip to content

Commit

Permalink
[NNAPI EP] Make partitioning stop ops configurable. (#8444)
Browse files Browse the repository at this point in the history
Enable NNAPI EP partitioning stop ops to be overridden by a session configuration option.
  • Loading branch information
edgchen1 authored Jul 22, 2021
1 parent 892ac9f commit 989491c
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,16 @@ static const char* const kOrtSessionOptionsUseDeviceAllocatorForInitializers = "
// "1": default, thread will spin a number of times before blocking
static const char* const kOrtSessionOptionsConfigAllowInterOpSpinning = "session.inter_op.allow_spinning";
static const char* const kOrtSessionOptionsConfigAllowIntraOpSpinning = "session.intra_op.allow_spinning";

// NNAPI EP keys begin
// Note: These options should be specified prior to appending the NNAPI EP to the session options object in order for
// them to take effect.

// Specifies a list of stop op types. Nodes of a type in the stop op types and nodes downstream from them will not be
// run by the NNAPI EP.
// The value should be a ","-delimited list of op types. For example, "Add,Sub".
// If not specified, the default set of stop ops is used. To specify an empty stop ops types list and disable stop op
// exclusion, set the value to "".
static const char* const kOrtSessionOptionsConfigNnapiEpPartitioningStopOps = "ep.nnapi.partitioning_stop_ops";

// NNAPI EP keys end
40 changes: 40 additions & 0 deletions onnxruntime/core/common/string_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string_view>
#include <vector>

#include "core/common/common.h"

namespace onnxruntime {
namespace utils {

/**
* Splits a string into substrings delimited by the given delimiter string.
* @param string_to_split The string to split.
* @param delimiter The delimiter string.
* @param keep_empty Whether to keep empty substrings.
* @return The split substrings.
*/
inline std::vector<std::string_view> SplitString(std::string_view string_to_split, std::string_view delimiter,
bool keep_empty = false) {
ORT_ENFORCE(!delimiter.empty(), "delimiter must not be empty");
std::vector<std::string_view> result{};
std::string_view::size_type segment_begin_pos = 0;
while (segment_begin_pos != std::string_view::npos) {
const std::string_view::size_type segment_end_pos = string_to_split.find(delimiter, segment_begin_pos);
const bool is_segment_empty = segment_begin_pos == segment_end_pos || segment_begin_pos == string_to_split.size();
if (!is_segment_empty || keep_empty) {
result.push_back(string_to_split.substr(segment_begin_pos, segment_end_pos - segment_begin_pos));
}
segment_begin_pos = (segment_end_pos == std::string_view::npos)
? segment_end_pos
: segment_end_pos + delimiter.size();
}
return result;
}

} // namespace utils
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"

#include "core/common/string_utils.h"
#include "core/framework/allocatormgr.h"
#include "core/framework/compute_capability.h"
#include "core/graph/graph_viewer.h"
#include "core/platform/env.h"
#include "core/providers/common.h"
#include "core/providers/nnapi/nnapi_builtin/builders/helper.h"
#include "core/providers/nnapi/nnapi_builtin/builders/op_support_checker.h"
Expand All @@ -20,17 +22,31 @@

namespace onnxruntime {

namespace {

constexpr const char* NNAPI = "Nnapi";

constexpr std::array kDefaultPartitioningStopOps{
"NonMaxSuppression",
};

NnapiExecutionProvider::NnapiExecutionProvider(uint32_t nnapi_flags)
std::unordered_set<std::string> GetPartitioningStopOps(const optional<std::unordered_set<std::string>>& partitioning_stop_ops) {
if (!partitioning_stop_ops.has_value()) {
LOGS_DEFAULT(VERBOSE) << "Using default partitioning stop ops list.";
return std::unordered_set<std::string>(kDefaultPartitioningStopOps.begin(), kDefaultPartitioningStopOps.end());
}

LOGS_DEFAULT(INFO) << "Using partitioning stop ops list from configuration.";
return partitioning_stop_ops.value();
}

} // namespace

NnapiExecutionProvider::NnapiExecutionProvider(uint32_t nnapi_flags,
const optional<std::unordered_set<std::string>>& partitioning_stop_ops)
: IExecutionProvider{onnxruntime::kNnapiExecutionProvider, true},
nnapi_flags_(nnapi_flags),
// TODO make this configurable
partitioning_stop_ops_(kDefaultPartitioningStopOps.begin(), kDefaultPartitioningStopOps.end()) {
partitioning_stop_ops_(GetPartitioningStopOps(partitioning_stop_ops)) {
AllocatorCreationInfo device_info(
[](int) {
return std::make_unique<CPUAllocator>(OrtMemoryInfo(NNAPI, OrtAllocatorType::OrtDeviceAllocator));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include "core/common/optional.h"
#include "core/framework/execution_provider.h"
#include "core/providers/nnapi/nnapi_provider_factory.h"

Expand All @@ -13,7 +14,9 @@ class Model;

class NnapiExecutionProvider : public IExecutionProvider {
public:
NnapiExecutionProvider(uint32_t nnapi_flags);
explicit NnapiExecutionProvider(uint32_t nnapi_flags,
const optional<std::unordered_set<std::string>>& partitioning_stop_ops = nullopt);

virtual ~NnapiExecutionProvider();

std::vector<std::unique_ptr<ComputeCapability>>
Expand Down
46 changes: 37 additions & 9 deletions onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc
Original file line number Diff line number Diff line change
@@ -1,31 +1,59 @@
// Copyright 2019 JD.com Inc. JD AI

#include "core/providers/nnapi/nnapi_provider_factory.h"
#include "core/session/abi_session_options_impl.h"
#include "nnapi_builtin/nnapi_execution_provider.h"

using namespace onnxruntime;
#include "core/common/optional.h"
#include "core/common/string_utils.h"
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h"
#include "core/session/abi_session_options_impl.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

namespace onnxruntime {

namespace {
struct NnapiProviderFactory : IExecutionProviderFactory {
NnapiProviderFactory(uint32_t nnapi_flags)
: nnapi_flags_(nnapi_flags) {}
NnapiProviderFactory(uint32_t nnapi_flags,
const optional<std::unordered_set<std::string>>& partitioning_stop_ops)
: nnapi_flags_(nnapi_flags),
partitioning_stop_ops_(partitioning_stop_ops) {}

~NnapiProviderFactory() override {}

std::unique_ptr<IExecutionProvider> CreateProvider() override;
uint32_t nnapi_flags_;

private:
const uint32_t nnapi_flags_;
const optional<std::unordered_set<std::string>> partitioning_stop_ops_;
};

std::unique_ptr<IExecutionProvider> NnapiProviderFactory::CreateProvider() {
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_);
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_, partitioning_stop_ops_);
}

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi_Internal(
uint32_t nnapi_flags, const optional<std::unordered_set<std::string>>& partitioning_stop_ops) {
return std::make_shared<NnapiProviderFactory>(nnapi_flags, partitioning_stop_ops);
}
} // namespace

std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t nnapi_flags) {
return std::make_shared<onnxruntime::NnapiProviderFactory>(nnapi_flags);
return CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, nullopt);
}

} // namespace onnxruntime

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, uint32_t nnapi_flags) {
options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_Nnapi(nnapi_flags));
const auto partitioning_stop_ops = [&]() -> onnxruntime::optional<std::unordered_set<std::string>> {
if (std::string partitioning_stop_ops_value{};
options->value.config_options.TryGetConfigEntry(kOrtSessionOptionsConfigNnapiEpPartitioningStopOps,
partitioning_stop_ops_value)) {
const auto partitioning_stop_ops_list = onnxruntime::utils::SplitString(partitioning_stop_ops_value, ",");
return std::unordered_set<std::string>(partitioning_stop_ops_list.begin(), partitioning_stop_ops_list.end());
}
return onnxruntime::nullopt;
}();

options->provider_factories.push_back(
onnxruntime::CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, partitioning_stop_ops));
return nullptr;
}
47 changes: 47 additions & 0 deletions onnxruntime/test/common/string_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

#include "core/common/make_string.h"
#include "core/common/parse_string.h"
#include "core/common/string_utils.h"

#include <algorithm>

#include "gtest/gtest.h"

Expand Down Expand Up @@ -89,5 +92,49 @@ TEST(StringUtilsTest, MakeStringAndTryParseStringWithCustomType) {
ASSERT_EQ(parsed_s, s);
}

TEST(StringUtilsTest, SplitString) {
auto run_test = [](const std::string& string_to_split, const std::string& delimiter,
const std::vector<std::string>& expected_substrings_with_empty) {
SCOPED_TRACE(MakeString("string_to_split: \"", string_to_split, "\", delimiter: \"", delimiter, "\""));

auto test_split = [&](const std::vector<std::string>& expected_substrings, bool keep_empty) {
SCOPED_TRACE(MakeString("keep_empty: ", keep_empty));

const auto actual_substrings = utils::SplitString(string_to_split, delimiter, keep_empty);
ASSERT_EQ(actual_substrings.size(), expected_substrings.size());
for (size_t i = 0; i < actual_substrings.size(); ++i) {
EXPECT_EQ(actual_substrings[i], expected_substrings[i]) << "i=" << i;
}
};

test_split(expected_substrings_with_empty, true);

const std::vector<std::string> expected_substrings_without_empty = [&]() {
std::vector<std::string> result = expected_substrings_with_empty;
result.erase(std::remove_if(result.begin(), result.end(),
[](const std::string& value) { return value.empty(); }),
result.end());
return result;
}();
test_split(expected_substrings_without_empty, false);
};

run_test("a,b,c", ",", {"a", "b", "c"});
run_test(",a,,b,,,c,", ",", {"", "a", "", "b", "", "", "c", ""});
run_test("one_delimiter_two_delimiter_", "_delimiter_", {"one", "two", ""});
run_test("aaaaaaa", "aa", {"", "", "", "a"});
run_test("abcabaabc", "abc", {"", "aba", ""});
run_test("leading,", ",", {"leading", ""});
run_test(",trailing", ",", {"", "trailing"});
run_test("", ",", {""});
run_test(",", ",", {"", ""});
}

#ifndef ORT_NO_EXCEPTIONS
TEST(StringUtilsTest, SplitStringWithEmptyDelimiter) {
EXPECT_THROW(utils::SplitString("a", ""), OnnxRuntimeException);
}
#endif

} // namespace test
} // namespace onnxruntime

0 comments on commit 989491c

Please sign in to comment.