Skip to content

Commit

Permalink
Introduce Resources with mapping, so graphs can use placeholders inst…
Browse files Browse the repository at this point in the history
…ead of actual resource paths.

PiperOrigin-RevId: 675346445
  • Loading branch information
MediaPipe Team authored and copybara-github committed Sep 17, 2024
1 parent 502f396 commit feb192b
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 15 deletions.
3 changes: 3 additions & 0 deletions mediapipe/framework/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,7 @@ cc_library(
deps = [
"//mediapipe/framework/port:status",
"//mediapipe/util:resource_util",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
Expand All @@ -684,6 +685,7 @@ cc_test(
":resources",
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status_matchers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
Expand Down Expand Up @@ -721,6 +723,7 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:status",
"//mediapipe/framework/port:status_matchers",
"//mediapipe/framework/testdata:resource_path_cc_proto",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
65 changes: 51 additions & 14 deletions mediapipe/framework/calculator_graph_resources_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

#include <memory>
#include <string>
#include <utility>
Expand All @@ -21,6 +20,7 @@
#include "mediapipe/framework/port/status_matchers.h"
#include "mediapipe/framework/resources.h"
#include "mediapipe/framework/resources_service.h"
#include "mediapipe/framework/testdata/resource_path.pb.h"

namespace mediapipe {
namespace {
Expand All @@ -46,15 +46,19 @@ class TestResourcesCalculator : public Node {
MEDIAPIPE_NODE_CONTRACT(kSideOut, kOut);

absl::Status Open(CalculatorContext* cc) override {
MP_ASSIGN_OR_RETURN(std::unique_ptr<Resource> resource,
cc->GetResources().Get(kCalculatorResource));
MP_ASSIGN_OR_RETURN(
std::unique_ptr<Resource> resource,
cc->GetResources().Get(
cc->Options<mediapipe::ResourcePathOptions>().path()));
kSideOut(cc).Set(api2::PacketAdopting(std::move(resource)));
return absl::OkStatus();
}

absl::Status Process(CalculatorContext* cc) override {
MP_ASSIGN_OR_RETURN(std::unique_ptr<Resource> resource,
cc->GetResources().Get(kCalculatorResource));
MP_ASSIGN_OR_RETURN(
std::unique_ptr<Resource> resource,
cc->GetResources().Get(
cc->Options<mediapipe::ResourcePathOptions>().path()));
kOut(cc).Send(std::move(resource));
return tool::StatusStop();
}
Expand All @@ -65,8 +69,10 @@ class TestResourcesSubgraph : public Subgraph {
public:
absl::StatusOr<CalculatorGraphConfig> GetConfig(
SubgraphContext* sc) override {
MP_ASSIGN_OR_RETURN(std::unique_ptr<Resource> resource,
sc->GetResources().Get(kSubgraphResource));
MP_ASSIGN_OR_RETURN(
std::unique_ptr<Resource> resource,
sc->GetResources().Get(
sc->Options<mediapipe::ResourcePathOptions>().path()));
Graph graph;
auto& constants_node = graph.AddNode("ConstantSidePacketCalculator");
auto& constants_options =
Expand All @@ -90,13 +96,18 @@ struct ResourceContentsPackets {
Packet calculator_side_out;
};

CalculatorGraphConfig BuildGraphProducingResourceContentsPackets() {
CalculatorGraphConfig BuildGraphProducingResourceContentsPackets(
absl::string_view calculator_path, absl::string_view subgraph_path) {
Graph graph;

auto& subgraph = graph.AddNode("TestResourcesSubgraph");
subgraph.GetOptions<mediapipe::ResourcePathOptions>().set_path(
std::string(subgraph_path));
subgraph.SideOut("SIDE_OUT").SetName("subgraph_side_out");

auto& calculator = graph.AddNode("TestResourcesCalculator");
calculator.GetOptions<mediapipe::ResourcePathOptions>().set_path(
std::string(calculator_path));
calculator.SideOut("SIDE_OUT").SetName("calculator_side_out");
calculator.Out("OUT").SetName("calculator_out");

Expand Down Expand Up @@ -129,8 +140,9 @@ RunGraphAndCollectResourceContentsPackets(CalculatorGraph& calculator_graph) {

TEST(CalculatorGraphResourcesTest, GraphAndContextsHaveDefaultResources) {
CalculatorGraph calculator_graph;
MP_ASSERT_OK(calculator_graph.Initialize(
BuildGraphProducingResourceContentsPackets()));
MP_ASSERT_OK(
calculator_graph.Initialize(BuildGraphProducingResourceContentsPackets(
kCalculatorResource, kSubgraphResource)));
MP_ASSERT_OK_AND_ASSIGN(
ResourceContentsPackets packets,
RunGraphAndCollectResourceContentsPackets(calculator_graph));
Expand Down Expand Up @@ -189,8 +201,9 @@ TEST(CalculatorGraphResourcesTest, CustomResourcesCanBeSetOnGraph) {
std::shared_ptr<Resources> resources = std::make_shared<CustomResources>();
MP_ASSERT_OK(calculator_graph.SetServiceObject(kResourcesService,
std::move(resources)));
MP_ASSERT_OK(calculator_graph.Initialize(
BuildGraphProducingResourceContentsPackets()));
MP_ASSERT_OK(
calculator_graph.Initialize(BuildGraphProducingResourceContentsPackets(
kCalculatorResource, kSubgraphResource)));
MP_ASSERT_OK_AND_ASSIGN(
ResourceContentsPackets packets,
RunGraphAndCollectResourceContentsPackets(calculator_graph));
Expand Down Expand Up @@ -234,8 +247,9 @@ TEST(CalculatorGraphResourcesTest,
std::make_shared<CustomizedDefaultResources>();
MP_ASSERT_OK(calculator_graph.SetServiceObject(kResourcesService,
std::move(resources)));
MP_ASSERT_OK(calculator_graph.Initialize(
BuildGraphProducingResourceContentsPackets()));
MP_ASSERT_OK(
calculator_graph.Initialize(BuildGraphProducingResourceContentsPackets(
kCalculatorResource, kSubgraphResource)));
MP_ASSERT_OK_AND_ASSIGN(
ResourceContentsPackets packets,
RunGraphAndCollectResourceContentsPackets(calculator_graph));
Expand All @@ -248,5 +262,28 @@ TEST(CalculatorGraphResourcesTest,
"Customized: File system calculator contents\n");
}

TEST(CalculatorGraphResourcesTest,
DefaultResourcesWithMappingCanBeSetAndUsedOnGraph) {
CalculatorGraph calculator_graph;
std::shared_ptr<Resources> resources = CreateDefaultResourcesWithMapping(
{{"$CALCULATOR_PATH", std::string(kCalculatorResource)},
{"$SUBGRAPH_PATH", std::string(kSubgraphResource)}});
MP_ASSERT_OK(calculator_graph.SetServiceObject(kResourcesService,
std::move(resources)));
MP_ASSERT_OK(
calculator_graph.Initialize(BuildGraphProducingResourceContentsPackets(
"$CALCULATOR_PATH", "$SUBGRAPH_PATH")));
MP_ASSERT_OK_AND_ASSIGN(
ResourceContentsPackets packets,
RunGraphAndCollectResourceContentsPackets(calculator_graph));

EXPECT_EQ(packets.subgraph_side_out.Get<std::string>(),
"File system subgraph contents\n");
EXPECT_EQ(packets.calculator_out.Get<Resource>().ToStringView(),
"File system calculator contents\n");
EXPECT_EQ(packets.calculator_side_out.Get<Resource>().ToStringView(),
"File system calculator contents\n");
}

} // namespace
} // namespace mediapipe
50 changes: 50 additions & 0 deletions mediapipe/framework/resources.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -47,6 +48,42 @@ class DefaultResources : public Resources {
}
};

class ResourcesWithMapping : public Resources {
public:
explicit ResourcesWithMapping(
std::unique_ptr<Resources> resources,
absl::flat_hash_map<std::string, std::string> mapping)
: resources_(std::move(resources)), mapping_(std::move(mapping)) {}

absl::Status ReadContents(absl::string_view resource_id, std::string& output,
const Options& options) const final {
auto iter = mapping_.find(resource_id);
absl::string_view resolved_res_id;
if (iter != mapping_.end()) {
resolved_res_id = iter->second;
} else {
resolved_res_id = resource_id;
}
return resources_->ReadContents(resolved_res_id, output, options);
}

absl::StatusOr<std::unique_ptr<Resource>> Get(
absl::string_view resource_id, const Options& options) const final {
auto iter = mapping_.find(resource_id);
absl::string_view resolved_res_id;
if (iter != mapping_.end()) {
resolved_res_id = iter->second;
} else {
resolved_res_id = resource_id;
}
return resources_->Get(resolved_res_id, options);
}

private:
std::unique_ptr<Resources> resources_;
absl::flat_hash_map<std::string, std::string> mapping_;
};

} // namespace

std::unique_ptr<Resource> MakeStringResource(std::string&& s) {
Expand All @@ -63,4 +100,17 @@ std::unique_ptr<Resources> CreateDefaultResources() {
return std::make_unique<DefaultResources>();
}

std::unique_ptr<Resources> CreateDefaultResourcesWithMapping(
absl::flat_hash_map<std::string, std::string> mapping) {
return CreateResourcesWithMapping(CreateDefaultResources(),
std::move(mapping));
}

std::unique_ptr<Resources> CreateResourcesWithMapping(
std::unique_ptr<Resources> resources,
absl::flat_hash_map<std::string, std::string> mapping) {
return std::make_unique<ResourcesWithMapping>(std::move(resources),
std::move(mapping));
}

} // namespace mediapipe
39 changes: 39 additions & 0 deletions mediapipe/framework/resources.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <memory>
#include <string>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -122,6 +123,44 @@ class Resources {
// `Resources` object which can be used in place of `GetResourceContents`.
std::unique_ptr<Resources> CreateDefaultResources();

// Creates `Resources` object which enables resource mapping within a graph and
// can be used in place of `GetResourceContents`.
//
// `mapping` keys are resources ids.
//
// Example:
//
// `CalculatorGraphConfig`:
// node {
// ...
// options {
// [type.googleapis.com/...] {
// model_path: "$MODEL"
// }
// }
// }
//
// `CalculatorGraph` setup:
//
// CalculatorGraph graph;
// std::shared_ptr<Resources> resources = CreateDefaultResourcesWithMapping(
// {{"$MODEL", "real/path/to/the/model"}});
// graph.SetServiceObject(kResourcesService, std::move(resources));
// graph.Initialize(std::move(config));
//
// As a result, when loading using ...Context::GetResources, not will be able
// to load the model from "real/path/to/the/model".
std::unique_ptr<Resources> CreateDefaultResourcesWithMapping(
absl::flat_hash_map<std::string, std::string> mapping);

// Wraps `resources` to provide resources by resource id using a mapping when
// available.
//
// `mapping` keys are resources ids.
std::unique_ptr<Resources> CreateResourcesWithMapping(
std::unique_ptr<Resources> resources,
absl::flat_hash_map<std::string, std::string> mapping);

} // namespace mediapipe

#endif // MEDIAPIPE_FRAMEWORK_RESOURCES_H_
42 changes: 41 additions & 1 deletion mediapipe/framework/resources_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

#include <memory>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand All @@ -24,7 +26,7 @@ TEST(Resources, CanCreateNoCleanupResource) {
EXPECT_EQ(resource->ToStringView(), "Test string.");
}

TEST(Resources, CanCreateDefaultResourcesThatAndReadFileContents) {
TEST(Resources, CanCreateDefaultResourcesAndReadFileContents) {
std::unique_ptr<Resources> resources = CreateDefaultResources();

std::string contents;
Expand All @@ -38,6 +40,21 @@ TEST(Resources, CanCreateDefaultResourcesThatAndReadFileContents) {
EXPECT_EQ(resource->ToStringView(), "File system calculator contents\n");
}

TEST(Resources, CanCreateDefaultResourcesWithMappingAndReadFileContents) {
absl::flat_hash_map<std::string, std::string> mapping = {
{"$CUSTOM_ID", "mediapipe/framework/testdata/resource_calculator.data"}};
std::unique_ptr<Resources> resources =
CreateDefaultResourcesWithMapping(std::move(mapping));

std::string contents;
MP_ASSERT_OK(resources->ReadContents("$CUSTOM_ID", contents));
EXPECT_EQ(contents, "File system calculator contents\n");

MP_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Resource> resource,
resources->Get("$CUSTOM_ID"));
EXPECT_EQ(resource->ToStringView(), "File system calculator contents\n");
}

class CustomResources : public Resources {
public:
absl::Status ReadContents(absl::string_view resource_id, std::string& output,
Expand Down Expand Up @@ -80,5 +97,28 @@ TEST(Resources, CanCreateCustomResourcesAndReuseDefault) {
EXPECT_EQ(resource->ToStringView(), "Custom content.");
}

TEST(Resources, CanCreateCustomResourcesAndUseMapping) {
std::unique_ptr<Resources> resources = std::make_unique<CustomResources>();
absl::flat_hash_map<std::string, std::string> mapping = {
{"$CUSTOM_ID", "custom/resource/id"}};
resources =
CreateResourcesWithMapping(std::move(resources), std::move(mapping));

std::string contents;
MP_ASSERT_OK(resources->ReadContents(
"mediapipe/framework/testdata/resource_calculator.data", contents));
EXPECT_EQ(contents, "File system calculator contents\n");
MP_ASSERT_OK(resources->ReadContents("$CUSTOM_ID", contents));
EXPECT_EQ(contents, "Custom content.");

std::unique_ptr<Resource> resource;
MP_ASSERT_OK_AND_ASSIGN(
resource,
resources->Get("mediapipe/framework/testdata/resource_calculator.data"));
EXPECT_EQ(resource->ToStringView(), "File system calculator contents\n");
MP_ASSERT_OK_AND_ASSIGN(resource, resources->Get("$CUSTOM_ID"));
EXPECT_EQ(resource->ToStringView(), "Custom content.");
}

} // namespace
} // namespace mediapipe
6 changes: 6 additions & 0 deletions mediapipe/framework/testdata/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ licenses(["notice"])

package(default_visibility = ["//visibility:private"])

mediapipe_proto_library(
name = "resource_path_proto",
srcs = ["resource_path.proto"],
visibility = ["//visibility:public"],
)

mediapipe_proto_library(
name = "sky_light_calculator_proto",
srcs = ["sky_light_calculator.proto"],
Expand Down
26 changes: 26 additions & 0 deletions mediapipe/framework/testdata/resource_path.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright 2024 The MediaPipe Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Forked from mediapipe/framework/tool/source.proto.
// The forked proto must remain identical to the original proto and should be
// ONLY used by mediapipe open source project.

syntax = "proto3";

package mediapipe;

// A proto3 calculator options for testing.
message ResourcePathOptions {
string path = 1;
}

0 comments on commit feb192b

Please sign in to comment.