Skip to content

Commit

Permalink
Add option to set max sequence size in PackMediaSequenceCalculator in…
Browse files Browse the repository at this point in the history
…stead of having it hard coded.

PiperOrigin-RevId: 698699219
  • Loading branch information
MediaPipe Team authored and copybara-github committed Nov 21, 2024
1 parent 256bb14 commit 1077b73
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 11 deletions.
24 changes: 15 additions & 9 deletions mediapipe/calculators/tensorflow/pack_media_sequence_calculator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <climits>
#include <cstdint>
#include <limits>
#include <optional>
#include <string>
#include <vector>
Expand Down Expand Up @@ -105,6 +105,8 @@ namespace mpms = mediapipe::mediasequence;
// }
// }
namespace {
constexpr int kMaxProtoBytes = std::numeric_limits<int>::max();

// Converts a float to a byte: the input is clamped to [0, 1], scaled by 255
// and rounded to the nearest integer by adding 0.5 before truncation.
uint8_t ConvertFloatToByte(const float float_value) {
// std::clamp takes (value, lo, hi). The value to clamp must come first;
// passing the bounds first is undefined behavior whenever lo > hi.
float clamped_value = std::clamp(float_value, 0.0f, 1.0f);
return static_cast<uint8_t>(clamped_value * 255.0 + .5f);
Expand Down Expand Up @@ -361,15 +363,21 @@ class PackMediaSequenceCalculator : public CalculatorBase {
}
}

// Checks the serialized size of sequence_ (ByteSizeLong()) against a byte
// limit. The RET_CHECK_LT failure message names the sequence by its example
// id.
absl::Status VerifySize() {
constexpr int kMaxProtoBytes = INT_MAX;
absl::Status VerifySize(const PackMediaSequenceCalculatorOptions& options) {
// The size check only runs when skip_large_sequences is set; otherwise
// report success without measuring the sequence.
if (!options.skip_large_sequences()) {
return absl::OkStatus();
}

// A non-positive max_sequence_bytes means no explicit limit was configured,
// so fall back to kMaxProtoBytes.
const int max_bytes = (options.max_sequence_bytes() > 0)
? options.max_sequence_bytes()
: kMaxProtoBytes;

// Label for the error message: the example id when the sequence has one,
// otherwise the placeholder "example".
std::string id = mpms::HasExampleId(*sequence_)
? mpms::GetExampleId(*sequence_)
: "example";
RET_CHECK_LT(sequence_->ByteSizeLong(), kMaxProtoBytes)
<< "sequence '" << id
<< "' would be too many bytes to serialize after adding features.";
// Strictly-less-than: a sequence of exactly max_bytes bytes also fails.
RET_CHECK_LT(sequence_->ByteSizeLong(), max_bytes)
<< "sequence '" << id << "' with " << sequence_->ByteSizeLong()
<< " bytes would be more than " << max_bytes << " bytes.";
return absl::OkStatus();
}

Expand All @@ -381,9 +389,7 @@ class PackMediaSequenceCalculator : public CalculatorBase {
options.reconcile_region_annotations(), sequence_.get()));
}

if (options.skip_large_sequences()) {
RET_CHECK_OK(VerifySize());
}
RET_CHECK_OK(VerifySize(options));
if (options.output_only_if_all_present()) {
absl::Status status = VerifySequence();
if (!status.ok()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ message PackMediaSequenceCalculatorOptions {
// If true, will return an error status if an output sequence would be too
// many bytes to serialize.
optional bool skip_large_sequences = 7 [default = true];
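// If > 0, the size limit in bytes for an output sequence: a sequence of this
// many bytes or more returns an error status. If <= 0, the limit is the
// maximum int value. Only takes effect when skip_large_sequences is true.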
optional int32 max_sequence_bytes = 10 [default = -1];

// If true/false, outputs the SequenceExample at timestamp 0/PostStream.
optional bool output_as_zero_timestamp = 8 [default = false];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
const bool replace_instead_of_append,
const bool output_as_zero_timestamp = false,
const bool add_empty_labels = false,
const std::vector<std::string>& input_side_packets = {
"SEQUENCE_EXAMPLE:input_sequence"}) {
const std::vector<std::string>& input_side_packets =
{"SEQUENCE_EXAMPLE:input_sequence"},
const int32_t max_sequence_bytes = -1) {
CalculatorGraphConfig::Node config;
config.set_calculator("PackMediaSequenceCalculator");
for (const std::string& side_packet : input_side_packets) {
Expand All @@ -103,6 +104,7 @@ class PackMediaSequenceCalculatorTest : public ::testing::Test {
options->set_replace_data_instead_of_append(replace_instead_of_append);
options->set_output_as_zero_timestamp(output_as_zero_timestamp);
options->set_add_empty_labels(add_empty_labels);
options->set_max_sequence_bytes(max_sequence_bytes);
runner_ = ::absl::make_unique<CalculatorRunner>(config);
}

Expand Down Expand Up @@ -1987,5 +1989,37 @@ TEST_F(PackMediaSequenceCalculatorTest, TestTooLargeInputFailsSoftly) {
ASSERT_FALSE(runner_->Run().ok());
}

// Runs the calculator with max_sequence_bytes set to 10 on a sequence holding
// two encoded images and expects the run status to mention that byte limit.
TEST_F(PackMediaSequenceCalculatorTest, SkipLargeSequence) {
  SetUpCalculator({"IMAGE:images"}, {}, false, true, false, false,
                  {"SEQUENCE_EXAMPLE:input_sequence"},
                  /*max_sequence_bytes=*/10);

  // Encode one small image; the same encoded payload is reused for every
  // input packet below.
  cv::Mat pixels(2, 3, CV_8UC3, cv::Scalar(0, 0, 255));
  std::vector<uchar> encoded_bytes;
  ASSERT_TRUE(cv::imencode(".jpg", pixels, encoded_bytes,
                           {cv::IMWRITE_HDR_COMPRESSION, 1}));
  OpenCvImageEncoderCalculatorResults encoded_image;
  encoded_image.set_encoded_image(encoded_bytes.data(), encoded_bytes.size());
  encoded_image.set_width(2);
  encoded_image.set_height(1);

  // Feed one copy of the encoded image per timestamp on the IMAGE stream.
  constexpr int kImageCount = 2;
  auto& image_packets = runner_->MutableInputs()->Tag(kImageTag).packets;
  for (int timestamp = 0; timestamp < kImageCount; ++timestamp) {
    auto image_copy =
        ::absl::make_unique<OpenCvImageEncoderCalculatorResults>(encoded_image);
    image_packets.push_back(
        Adopt(image_copy.release()).At(Timestamp(timestamp)));
  }

  // The input sequence carries only a clip media id.
  auto sequence = ::absl::make_unique<tf::SequenceExample>();
  std::string media_id = "test_video_id";
  mpms::SetClipMediaId(media_id, sequence.get());
  runner_->MutableSidePackets()->Tag(kSequenceExampleTag) =
      Adopt(sequence.release());

  const absl::Status run_status = runner_->Run();
  EXPECT_THAT(run_status.ToString(),
              ::testing::HasSubstr("bytes would be more than 10 bytes"));
}

} // namespace
} // namespace mediapipe

0 comments on commit 1077b73

Please sign in to comment.