diff --git a/mediapipe/util/tflite/BUILD b/mediapipe/util/tflite/BUILD index b34d0e080f..a630f29743 100644 --- a/mediapipe/util/tflite/BUILD +++ b/mediapipe/util/tflite/BUILD @@ -13,7 +13,11 @@ # limitations under the License. # -load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") +load( + "@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", + "cc_library_with_tflite", + "cc_test_with_tflite", +) licenses(["notice"]) @@ -141,3 +145,27 @@ cc_library_with_tflite( "//mediapipe/util:resource_util", ], ) + +cc_test_with_tflite( + name = "tflite_model_loader_test", + srcs = ["tflite_model_loader_test.cc"], + data = [ + ":testdata/test_model.tflite", + ], + tflite_deps = [ + ":tflite_model_loader", + "@org_tensorflow//tensorflow/lite:test_util", + ], + deps = [ + "//mediapipe/framework:calculator_cc_proto", + "//mediapipe/framework:calculator_context", + "//mediapipe/framework:calculator_framework", + "//mediapipe/framework:calculator_state", + "//mediapipe/framework:legacy_calculator_support", + "//mediapipe/framework/api2:packet", + "//mediapipe/framework/port:gtest_main", + "//mediapipe/framework/tool:tag_map_helper", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/strings", + ], +) diff --git a/mediapipe/util/tflite/testdata/test_model.tflite b/mediapipe/util/tflite/testdata/test_model.tflite new file mode 100644 index 0000000000..b4c02350c0 Binary files /dev/null and b/mediapipe/util/tflite/testdata/test_model.tflite differ diff --git a/mediapipe/util/tflite/tflite_model_loader.cc b/mediapipe/util/tflite/tflite_model_loader.cc index 766543f9cd..83b725c66d 100644 --- a/mediapipe/util/tflite/tflite_model_loader.cc +++ b/mediapipe/util/tflite/tflite_model_loader.cc @@ -14,6 +14,9 @@ #include "mediapipe/util/tflite/tflite_model_loader.h" +#include +#include + #include "mediapipe/framework/port/ret_check.h" #include "mediapipe/util/resource_util.h" @@ -26,11 +29,10 @@ absl::StatusOr> TfLiteModelLoader::LoadFromPath( std::string model_path = path; std::string model_blob; - auto status_or_content = - mediapipe::GetResourceContents(model_path, &model_blob); + absl::Status status = mediapipe::GetResourceContents(model_path, &model_blob); // TODO: get rid of manual resolving with PathToResourceAsFile // as soon as it's incorporated into GetResourceContents. - if (!status_or_content.ok()) { + if (!status.ok()) { MP_ASSIGN_OR_RETURN(auto resolved_path, mediapipe::PathToResourceAsFile(model_path)); VLOG(2) << "Loading the model from " << resolved_path; @@ -40,6 +42,7 @@ absl::StatusOr> TfLiteModelLoader::LoadFromPath( auto model = FlatBufferModel::VerifyAndBuildFromBuffer(model_blob.data(), model_blob.size()); + RET_CHECK(model) << "Failed to load model from path " << model_path; return api2::MakePacket( model.release(), diff --git a/mediapipe/util/tflite/tflite_model_loader_test.cc b/mediapipe/util/tflite/tflite_model_loader_test.cc new file mode 100644 index 0000000000..81f411795b --- /dev/null +++ b/mediapipe/util/tflite/tflite_model_loader_test.cc @@ -0,0 +1,68 @@ +#include "mediapipe/util/tflite/tflite_model_loader.h" + +#include +#include + +#include "absl/flags/declare.h" +#include "absl/flags/flag.h" +#include "absl/strings/str_cat.h" +#include "mediapipe/framework/api2/packet.h" +#include "mediapipe/framework/calculator.pb.h" +#include "mediapipe/framework/calculator_context.h" +#include "mediapipe/framework/calculator_framework.h" +#include "mediapipe/framework/calculator_state.h" +#include "mediapipe/framework/legacy_calculator_support.h" +#include "mediapipe/framework/port/gtest.h" +#include "mediapipe/framework/port/status_matchers.h" +#include "mediapipe/framework/tool/tag_map_helper.h" +#include "tensorflow/lite/test_util.h" + +ABSL_DECLARE_FLAG(std::string, resource_root_dir); + +namespace mediapipe { +namespace { + +constexpr char kModelDir[] = "mediapipe/util/tflite/testdata"; +constexpr char kModelFilename[] = "test_model.tflite"; + +class TfLiteModelLoaderTest : public tflite::testing::Test { + void SetUp() override { + // Create a stub calculator state. + CalculatorGraphConfig::Node config; + calculator_state_ = std::make_unique( + "fake_node", 0, "fake_type", config, nullptr); + + // Create a stub calculator context. + calculator_context_ = std::make_unique( + calculator_state_.get(), tool::CreateTagMap({}).value(), + tool::CreateTagMap({}).value()); + } + + protected: + std::unique_ptr calculator_state_; + std::unique_ptr calculator_context_; + std::string model_path_ = absl::StrCat(kModelDir, "/", kModelFilename); +}; + +TEST_F(TfLiteModelLoaderTest, LoadFromPath) { + // TODO: remove LegacyCalculatorSupport usage. + LegacyCalculatorSupport::Scoped scope( + calculator_context_.get()); + MP_ASSERT_OK_AND_ASSIGN(api2::Packet model, + TfLiteModelLoader::LoadFromPath(model_path_)); + EXPECT_NE(model.Get(), nullptr); +} + +TEST_F(TfLiteModelLoaderTest, LoadFromPathRelativeToRootDir) { + absl::SetFlag(&FLAGS_resource_root_dir, kModelDir); + + // TODO: remove LegacyCalculatorSupport usage. + LegacyCalculatorSupport::Scoped scope( + calculator_context_.get()); + MP_ASSERT_OK_AND_ASSIGN(api2::Packet model, + TfLiteModelLoader::LoadFromPath(kModelFilename)); + EXPECT_NE(model.Get(), nullptr); +} + +} // namespace +} // namespace mediapipe