Skip to content

Commit

Permalink
Enable BeginLoopCalculator for move-only types (e.g. Tensor) with…
Browse files Browse the repository at this point in the history
…out `Packet::Consume` usage and copyable types without copying unless it's a fundamental type.

PiperOrigin-RevId: 633693734
  • Loading branch information
MediaPipe Team authored and copybara-github committed May 14, 2024
1 parent 7a62bb5 commit 22a978d
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 29 deletions.
2 changes: 2 additions & 0 deletions mediapipe/calculators/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ cc_test(
"//mediapipe/framework/port:gtest_main",
"//mediapipe/framework/port:parse_text_proto",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/types:span",
],
)

Expand Down
51 changes: 48 additions & 3 deletions mediapipe/calculators/core/begin_end_loop_calculator_graph_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cstring>
#include <string>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/types/span.h"
#include "mediapipe/calculators/core/begin_loop_calculator.h"
#include "mediapipe/calculators/core/end_loop_calculator.h"
#include "mediapipe/framework/calculator_contract.h"
Expand Down Expand Up @@ -514,6 +518,37 @@ TEST_F(BeginEndLoopCalculatorGraphWithClonedInputsTest, MultipleVectors) {
PacketOfIntsEq(input_timestamp2, std::vector<int>{6, 9})));
}

class TestTensorCpuCopyCalculator : public CalculatorBase {
public:
static absl::Status GetContract(CalculatorContract* cc) {
cc->Inputs().Index(0).Set<Tensor>();
cc->Outputs().Index(0).Set<Tensor>();
return absl::OkStatus();
}

absl::Status Open(CalculatorContext* cc) override {
cc->SetOffset(TimestampDiff(0));
return absl::OkStatus();
}

absl::Status Process(CalculatorContext* cc) override {
const Tensor& in_tensor = cc->Inputs().Index(0).Get<Tensor>();
const Tensor::CpuReadView in_view = in_tensor.GetCpuReadView();
const void* in_data = in_view.buffer<void>();

Tensor out_tensor(in_tensor.element_type(), in_tensor.shape());
auto out_view = out_tensor.GetCpuWriteView();
void* out_data = out_view.buffer<void>();

std::memcpy(out_data, in_data, in_tensor.bytes());

cc->Outputs().Index(0).AddPacket(
MakePacket<Tensor>(std::move(out_tensor)).At(cc->InputTimestamp()));
return absl::OkStatus();
}
};
REGISTER_CALCULATOR(TestTensorCpuCopyCalculator);

absl::Status InitBeginEndTensorLoopTestGraph(
CalculatorGraph& graph, std::vector<Packet>& output_packets) {
auto graph_config = ParseTextProtoOrDie<CalculatorGraphConfig>(
Expand All @@ -527,13 +562,13 @@ absl::Status InitBeginEndTensorLoopTestGraph(
output_stream: "BATCH_END:timestamp"
}
node {
calculator: "PassThroughCalculator"
calculator: "TestTensorCpuCopyCalculator"
input_stream: "tensor"
output_stream: "passed_tensor"
output_stream: "copied_tensor"
}
node {
calculator: "EndLoopTensorCalculator"
input_stream: "ITEM:passed_tensor"
input_stream: "ITEM:copied_tensor"
input_stream: "BATCH_END:timestamp"
output_stream: "ITERABLE:output_tensors"
}
Expand All @@ -555,6 +590,11 @@ TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) {
for (int i = 0; i < 4; i++) {
tensors.emplace_back(Tensor::ElementType::kFloat32,
Tensor::Shape{4, 3, 2, 1});
auto write_view = tensors.back().GetCpuWriteView();
float* data = write_view.buffer<float>();

// Populate with tensor index in the vector.
std::fill(data, data + tensors.back().element_size(), i);
}
Packet vector_packet =
MakePacket<std::vector<mediapipe::Tensor>>(std::move(tensors));
Expand All @@ -570,6 +610,11 @@ TEST(BeginEndTensorLoopCalculatorGraphTest, SingleNonEmptyVector) {
for (int i = 0; i < output_tensors.size(); i++) {
EXPECT_THAT(output_tensors[i].shape().dims,
testing::ElementsAre(4, 3, 2, 1));
const float* data = output_tensors[i].GetCpuReadView().buffer<float>();

// Expect every element is equal to tensor index.
EXPECT_THAT(absl::MakeSpan(data, output_tensors[i].element_size()),
testing::Each(i));
}

MP_ASSERT_OK(graph.CloseAllPacketSources());
Expand Down
44 changes: 18 additions & 26 deletions mediapipe/calculators/core/begin_loop_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#ifndef MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_
#define MEDIAPIPE_CALCULATORS_CORE_BEGIN_LOOP_CALCULATOR_H_

#include <type_traits>
#include <utility>

#include "absl/status/status.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_contract.h"
Expand Down Expand Up @@ -135,37 +138,26 @@ class BeginLoopCalculator : public CalculatorBase {
absl::Status Process(CalculatorContext* cc) final {
Timestamp last_timestamp = loop_internal_timestamp_;
if (!cc->Inputs().Tag("ITERABLE").IsEmpty()) {
// Try to consume the ITERABLE packet if possible to obtain the ownership
// and emit the item packets by moving them.
// If the ITERABLE packet is not consumable, then try to copy each item
// instead. If the ITEM type is not copy constructible, an error will be
// returned.
auto iterable_ptr_or =
cc->Inputs().Tag("ITERABLE").Value().Consume<IterableT>();
if (iterable_ptr_or.ok()) {
for (auto& item : *iterable_ptr_or.value()) {
Packet item_packet = MakePacket<ItemT>(std::move(item));
const Packet& iterable = cc->Inputs().Tag("ITERABLE").Value();
if constexpr (std::is_fundamental_v<ItemT>) {
for (ItemT item : iterable.Get<IterableT>()) {
cc->Outputs().Tag("ITEM").AddPacket(
item_packet.At(loop_internal_timestamp_));
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_;
}
} else {
if constexpr (std::is_copy_constructible<ItemT>()) {
const IterableT& collection =
cc->Inputs().Tag("ITERABLE").template Get<IterableT>();
for (const auto& item : collection) {
cc->Outputs().Tag("ITEM").AddPacket(
MakePacket<ItemT>(item).At(loop_internal_timestamp_));
ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_;
}
} else {
return absl::InternalError(
"The element type is not copiable. Consider making the "
"BeginLoopCalculator the sole owner of the input packet so that "
"the "
"items can be consumed and moved.");
for (const auto& item : iterable.Get<IterableT>()) {
Packet item_packet = PointToForeign(
&item, /*cleanup=*/[iterable_packet_copy = iterable]() mutable {
// Captures a copy of iterable packet and destroys it when
// packet representing an item is destroyed.
iterable_packet_copy = Packet();
});
cc->Outputs().Tag("ITEM").AddPacket(
item_packet.At(loop_internal_timestamp_));
ForwardClonePackets(cc, loop_internal_timestamp_);
++loop_internal_timestamp_;
}
}
}
Expand Down

0 comments on commit 22a978d

Please sign in to comment.