-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[NNAPI EP] Make partitioning stop ops configurable. (#8444)
Enable NNAPI EP partitioning stop ops to be overridden by a session configuration option.
- Loading branch information
Showing 6 changed files with 160 additions and 13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include <string_view> | ||
#include <vector> | ||
|
||
#include "core/common/common.h" | ||
|
||
namespace onnxruntime { | ||
namespace utils { | ||
|
||
/** | ||
* Splits a string into substrings delimited by the given delimiter string. | ||
* @param string_to_split The string to split. | ||
* @param delimiter The delimiter string. | ||
* @param keep_empty Whether to keep empty substrings. | ||
* @return The split substrings. | ||
*/ | ||
inline std::vector<std::string_view> SplitString(std::string_view string_to_split, std::string_view delimiter, | ||
bool keep_empty = false) { | ||
ORT_ENFORCE(!delimiter.empty(), "delimiter must not be empty"); | ||
std::vector<std::string_view> result{}; | ||
std::string_view::size_type segment_begin_pos = 0; | ||
while (segment_begin_pos != std::string_view::npos) { | ||
const std::string_view::size_type segment_end_pos = string_to_split.find(delimiter, segment_begin_pos); | ||
const bool is_segment_empty = segment_begin_pos == segment_end_pos || segment_begin_pos == string_to_split.size(); | ||
if (!is_segment_empty || keep_empty) { | ||
result.push_back(string_to_split.substr(segment_begin_pos, segment_end_pos - segment_begin_pos)); | ||
} | ||
segment_begin_pos = (segment_end_pos == std::string_view::npos) | ||
? segment_end_pos | ||
: segment_end_pos + delimiter.size(); | ||
} | ||
return result; | ||
} | ||
|
||
} // namespace utils | ||
} // namespace onnxruntime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
46 changes: 37 additions & 9 deletions
46
onnxruntime/core/providers/nnapi/nnapi_provider_factory.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,59 @@ | ||
// Copyright 2019 JD.com Inc. JD AI | ||
|
||
#include "core/providers/nnapi/nnapi_provider_factory.h" | ||
#include "core/session/abi_session_options_impl.h" | ||
#include "nnapi_builtin/nnapi_execution_provider.h" | ||
|
||
using namespace onnxruntime; | ||
#include "core/common/optional.h" | ||
#include "core/common/string_utils.h" | ||
#include "core/providers/nnapi/nnapi_builtin/nnapi_execution_provider.h" | ||
#include "core/session/abi_session_options_impl.h" | ||
#include "core/session/onnxruntime_session_options_config_keys.h" | ||
|
||
namespace onnxruntime { | ||
|
||
namespace { | ||
struct NnapiProviderFactory : IExecutionProviderFactory { | ||
NnapiProviderFactory(uint32_t nnapi_flags) | ||
: nnapi_flags_(nnapi_flags) {} | ||
NnapiProviderFactory(uint32_t nnapi_flags, | ||
const optional<std::unordered_set<std::string>>& partitioning_stop_ops) | ||
: nnapi_flags_(nnapi_flags), | ||
partitioning_stop_ops_(partitioning_stop_ops) {} | ||
|
||
~NnapiProviderFactory() override {} | ||
|
||
std::unique_ptr<IExecutionProvider> CreateProvider() override; | ||
uint32_t nnapi_flags_; | ||
|
||
private: | ||
const uint32_t nnapi_flags_; | ||
const optional<std::unordered_set<std::string>> partitioning_stop_ops_; | ||
}; | ||
|
||
std::unique_ptr<IExecutionProvider> NnapiProviderFactory::CreateProvider() { | ||
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_); | ||
return std::make_unique<NnapiExecutionProvider>(nnapi_flags_, partitioning_stop_ops_); | ||
} | ||
|
||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi_Internal( | ||
uint32_t nnapi_flags, const optional<std::unordered_set<std::string>>& partitioning_stop_ops) { | ||
return std::make_shared<NnapiProviderFactory>(nnapi_flags, partitioning_stop_ops); | ||
} | ||
} // namespace | ||
|
||
std::shared_ptr<IExecutionProviderFactory> CreateExecutionProviderFactory_Nnapi(uint32_t nnapi_flags) { | ||
return std::make_shared<onnxruntime::NnapiProviderFactory>(nnapi_flags); | ||
return CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, nullopt); | ||
} | ||
|
||
} // namespace onnxruntime | ||
|
||
// Appends a single NNAPI execution provider factory to the session options.
//
// If the session config contains an entry for
// kOrtSessionOptionsConfigNnapiEpPartitioningStopOps, its value is split on ','
// and the resulting set is passed to the factory as the partitioning stop ops;
// otherwise nullopt is passed.
//
// Returns nullptr (presumably the "no error" status of the C API - confirm
// against the ORT_API_STATUS_IMPL definition).
ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Nnapi, _In_ OrtSessionOptions* options, uint32_t nnapi_flags) {
  const auto partitioning_stop_ops = [&]() -> onnxruntime::optional<std::unordered_set<std::string>> {
    if (std::string partitioning_stop_ops_value{};
        options->value.config_options.TryGetConfigEntry(kOrtSessionOptionsConfigNnapiEpPartitioningStopOps,
                                                        partitioning_stop_ops_value)) {
      // SplitString is called with its default keep_empty, and the pieces are not
      // trimmed: surrounding whitespace in the config value becomes part of a name.
      const auto partitioning_stop_ops_list = onnxruntime::utils::SplitString(partitioning_stop_ops_value, ",");
      // Copy the views into owning strings before partitioning_stop_ops_value goes out of scope.
      return std::unordered_set<std::string>(partitioning_stop_ops_list.begin(), partitioning_stop_ops_list.end());
    }
    return onnxruntime::nullopt;
  }();

  // Register exactly one factory, carrying the (possibly absent) stop ops override.
  options->provider_factories.push_back(
      onnxruntime::CreateExecutionProviderFactory_Nnapi_Internal(nnapi_flags, partitioning_stop_ops));
  return nullptr;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters