diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc index 66456345d7..c373cd2236 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include +#include #include #include #include @@ -105,6 +105,8 @@ namespace mpms = mediapipe::mediasequence; // } // } namespace { +constexpr int kMaxProtoBytes = std::numeric_limits::max(); + uint8_t ConvertFloatToByte(const float float_value) { float clamped_value = std::clamp(0.0f, 1.0f, float_value); return static_cast(clamped_value * 255.0 + .5f); @@ -361,15 +363,21 @@ class PackMediaSequenceCalculator : public CalculatorBase { } } - absl::Status VerifySize() { - constexpr int kMaxProtoBytes = INT_MAX; + absl::Status VerifySize(const PackMediaSequenceCalculatorOptions& options) { + if (!options.skip_large_sequences()) { + return absl::OkStatus(); + } + + const int max_bytes = (options.max_sequence_bytes() > 0) + ? options.max_sequence_bytes() + : kMaxProtoBytes; std::string id = mpms::HasExampleId(*sequence_) ? mpms::GetExampleId(*sequence_) : "example"; - RET_CHECK_LT(sequence_->ByteSizeLong(), kMaxProtoBytes) - << "sequence '" << id - << "' would be too many bytes to serialize after adding features."; + RET_CHECK_LT(sequence_->ByteSizeLong(), max_bytes) + << "sequence '" << id << "' with " << sequence_->ByteSizeLong() + << " bytes would be more than " << max_bytes << " bytes."; return absl::OkStatus(); } @@ -381,9 +389,7 @@ class PackMediaSequenceCalculator : public CalculatorBase { options.reconcile_region_annotations(), sequence_.get())); } - if (options.skip_large_sequences()) { - RET_CHECK_OK(VerifySize()); - } + RET_CHECK_OK(VerifySize(options)); if (options.output_only_if_all_present()) { absl::Status status = VerifySequence(); if (!status.ok()) { diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto index cc6c2ffda4..a5c7bbbdf7 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator.proto @@ -65,6 +65,9 @@ message PackMediaSequenceCalculatorOptions { // If true, will return an error status if an output sequence would be too // many bytes to serialize. optional bool skip_large_sequences = 7 [default = true]; + // If > 0, will return an error status if an output sequence would be too + // many bytes to serialize. Otherwise uses int max. + optional int32 max_sequence_bytes = 10 [default = -1]; // If true/false, outputs the SequenceExample at timestamp 0/PostStream. optional bool output_as_zero_timestamp = 8 [default = false]; diff --git a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc index d8fbc94d53..fee2a8b48c 100644 --- a/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc +++ b/mediapipe/calculators/tensorflow/pack_media_sequence_calculator_test.cc @@ -85,8 +85,9 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test { const bool replace_instead_of_append, const bool output_as_zero_timestamp = false, const bool add_empty_labels = false, - const std::vector& input_side_packets = { - "SEQUENCE_EXAMPLE:input_sequence"}) { + const std::vector& input_side_packets = + {"SEQUENCE_EXAMPLE:input_sequence"}, + const int32_t max_sequence_bytes = -1) { CalculatorGraphConfig::Node config; config.set_calculator("PackMediaSequenceCalculator"); for (const std::string& side_packet : input_side_packets) { @@ -103,6 +104,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test { options->set_replace_data_instead_of_append(replace_instead_of_append); options->set_output_as_zero_timestamp(output_as_zero_timestamp); options->set_add_empty_labels(add_empty_labels); + options->set_max_sequence_bytes(max_sequence_bytes); runner_ = ::absl::make_unique(config); } @@ -1987,5 +1989,37 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) { ASSERT_FALSE(runner_->Run().ok()); } +TEST_F(PackMediaSequenceCalculatorTest, SkipLargeSequence) { + SetUpCalculator({"IMAGE:images"}, {}, false, true, false, false, + {"SEQUENCE_EXAMPLE:input_sequence"}, + /*max_sequence_bytes=*/10); + auto input_sequence = ::absl::make_unique(); + std::string test_video_id = "test_video_id"; + mpms::SetClipMediaId(test_video_id, input_sequence.get()); + cv::Mat image(2, 3, CV_8UC3, cv::Scalar(0, 0, 255)); + std::vector bytes; + ASSERT_TRUE( + cv::imencode(".jpg", image, bytes, {cv::IMWRITE_HDR_COMPRESSION, 1})); + OpenCvImageEncoderCalculatorResults encoded_image; + encoded_image.set_encoded_image(bytes.data(), bytes.size()); + encoded_image.set_width(2); + encoded_image.set_height(1); + + int num_images = 2; + for (int i = 0; i < num_images; ++i) { + auto image_ptr = + ::absl::make_unique(encoded_image); + runner_->MutableInputs()->Tag(kImageTag).packets.push_back( + Adopt(image_ptr.release()).At(Timestamp(i))); + } + + runner_->MutableSidePackets()->Tag(kSequenceExampleTag) = + Adopt(input_sequence.release()); + + absl::Status status = runner_->Run(); + EXPECT_THAT(status.ToString(), + ::testing::HasSubstr("bytes would be more than 10 bytes")); +} + } // namespace } // namespace mediapipe