Skip to content

Commit

Permalink
Add an HasError method and a test for ErrorReporter
Browse files Browse the repository at this point in the history
Using error_reporter.message().empty() is inefficient since that will allocate a std::string. This CL adds a HasError() method that directly tests the underlying buffer. Uses the new method in TfLiteModelLoader.

PiperOrigin-RevId: 617797494
  • Loading branch information
MediaPipe Team authored and copybara-github committed Mar 21, 2024
1 parent bab2d51 commit 7f8985d
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 9 deletions.
9 changes: 9 additions & 0 deletions mediapipe/util/tflite/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ cc_library(
],
)

cc_test(
name = "error_reporter_test",
srcs = ["error_reporter_test.cc"],
deps = [
":error_reporter",
"//mediapipe/framework/port:gtest_main",
],
)

# This target has an implementation dependency on TFLite/TFLite-in-GMSCore,
# but it does not have any API dependency on TFLite-in-GMSCore.
cc_library_with_tflite(
Expand Down
2 changes: 2 additions & 0 deletions mediapipe/util/tflite/error_reporter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ int ErrorReporter::Report(const char* format, va_list args) {
return num_characters;
}

bool ErrorReporter::HasError() const { return message_[0] != '\0'; }

std::string ErrorReporter::message() { return message_; }

std::string ErrorReporter::previous_message() { return previous_message_; }
Expand Down
14 changes: 7 additions & 7 deletions mediapipe/util/tflite/error_reporter.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/stateful_error_reporter.h"

namespace mediapipe {
namespace util {
namespace tflite {
namespace mediapipe::util::tflite {

// An ErrorReporter that logs to stderr and captures the last two messages.
class ErrorReporter : public ::tflite::StatefulErrorReporter {
public:
static constexpr int kBufferSize = 1024;

ErrorReporter();

// We declared two functions with name 'Report', so that the variadic Report
Expand All @@ -36,17 +36,17 @@ class ErrorReporter : public ::tflite::StatefulErrorReporter {

int Report(const char* format, std::va_list args) override;

// Returns true if any error was reported.
bool HasError() const;

std::string message() override;
std::string previous_message();

private:
static constexpr int kBufferSize = 1024;
char message_[kBufferSize];
char previous_message_[kBufferSize];
};

} // namespace tflite
} // namespace util
} // namespace mediapipe
} // namespace mediapipe::util::tflite

#endif // MEDIAPIPE_UTIL_TFLITE_ERROR_REPORTER_H_
55 changes: 55 additions & 0 deletions mediapipe/util/tflite/error_reporter_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#include "mediapipe/util/tflite/error_reporter.h"

#include <string>

#include "mediapipe/framework/port/gtest.h"

namespace mediapipe::util::tflite {
namespace {

TEST(ErrorReporterTest, ReportNoErrors) {
ErrorReporter error_reporter;
EXPECT_FALSE(error_reporter.HasError());
EXPECT_TRUE(error_reporter.message().empty());
EXPECT_TRUE(error_reporter.previous_message().empty());
}

TEST(ErrorReporterTest, ReportOneError) {
ErrorReporter error_reporter;
error_reporter.Report("error %i", 1);
EXPECT_TRUE(error_reporter.HasError());
EXPECT_EQ(error_reporter.message(), "error 1");
EXPECT_TRUE(error_reporter.previous_message().empty());
}

TEST(ErrorReporterTest, ReportTwoErrors) {
ErrorReporter error_reporter;
error_reporter.Report("error %i", 1);
error_reporter.Report("error %i", 2);
EXPECT_TRUE(error_reporter.HasError());
EXPECT_EQ(error_reporter.message(), "error 2");
EXPECT_EQ(error_reporter.previous_message(), "error 1");
}

TEST(ErrorReporterTest, ReportThreeErrors) {
ErrorReporter error_reporter;
error_reporter.Report("error %i", 1);
error_reporter.Report("error %i", 2);
error_reporter.Report("error %i", 3);
EXPECT_TRUE(error_reporter.HasError());
EXPECT_EQ(error_reporter.message(), "error 3");
EXPECT_EQ(error_reporter.previous_message(), "error 2");
}

TEST(ErrorReporterTest, VeryLongErrorIsTruncated) {
ErrorReporter error_reporter;
std::string long_error;
long_error.resize(ErrorReporter::kBufferSize * 2, 'x');
error_reporter.Report(long_error.c_str());
EXPECT_TRUE(error_reporter.HasError());
EXPECT_EQ(error_reporter.message(),
long_error.substr(0, ErrorReporter::kBufferSize - 1));
}

} // namespace
} // namespace mediapipe::util::tflite
3 changes: 1 addition & 2 deletions mediapipe/util/tflite/tflite_model_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ absl::StatusOr<api2::Packet<TfLiteModelPtr>> TfLiteModelLoader::LoadFromPath(
std::unique_ptr<Allocation> allocation =
std::make_unique<MMAPAllocation>(model_path.c_str(), &error_reporter);

bool mmap_allocation_succeeded = error_reporter.message().empty();
if (mmap_allocation_succeeded) {
if (!error_reporter.HasError()) {
auto model = FlatBufferModel::BuildFromAllocation(std::move(allocation),
&error_reporter);
if (model) {
Expand Down

0 comments on commit 7f8985d

Please sign in to comment.