Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 717950001
  • Loading branch information
MediaPipe Team authored and copybara-github committed Jan 21, 2025
1 parent cf6e4ec commit 7769ae3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {

XnnSubgraphPtr subgraph{subgraph_ptr, xnn_delete_subgraph};

for (auto& t : static_weights_) {
for (auto& t : static_weights_added_order_) {
MP_RETURN_IF_ERROR(t->DefineWeight(*subgraph_ptr));
}
for (auto& input : input_tensors_added_order_) {
Expand All @@ -190,7 +190,6 @@ absl::StatusOr<std::unique_ptr<XnnGraph>> XnnGraphBuilder::Build() {
std::make_unique<RuntimeConfigs>(*runtime_configs_));
result.input_tensors_ = std::move(input_tensors_added_order_);
result.output_tensors_ = std::move(output_tensors);
result.static_weights_ = std::move(static_weights_);

VLOG(2) << "XnnGraphBuilder::Build() creating runtime...";
MP_RETURN_IF_ERROR(result.CreateRuntime());
Expand Down Expand Up @@ -218,10 +217,12 @@ absl::Status XnnGraphBuilder::MarkInput(std::shared_ptr<Tensor> t) {
}

void XnnGraphBuilder::NewWeight(std::shared_ptr<Tensor> t) {
if (interm_tensors_.contains(t) || input_tensors_.contains(t)) {
if (interm_tensors_.contains(t) || input_tensors_.contains(t) ||
static_weights_.contains(t)) {
return;
}

static_weights_added_order_.push_back(t);
static_weights_.insert(t);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ class XnnGraphBuilder {
std::vector<std::shared_ptr<Tensor>> interm_tensors_added_order_;
// Intermediate tensors in hash_set, for easy existence check.
absl::flat_hash_set<std::shared_ptr<Tensor>> interm_tensors_;

// Static weights keeping the same order as how they were added.
std::vector<std::shared_ptr<Tensor>> static_weights_added_order_;
absl::flat_hash_set<std::shared_ptr<Tensor>> static_weights_;

// Caches
Expand Down Expand Up @@ -370,8 +373,6 @@ class XnnGraph {

std::vector<std::shared_ptr<Tensor>> input_tensors_;
std::vector<std::shared_ptr<Tensor>> output_tensors_;

absl::flat_hash_set<std::shared_ptr<Tensor>> static_weights_;
};

} // namespace xnn_utils
Expand Down

0 comments on commit 7769ae3

Please sign in to comment.