From 343aa0c6e69107c795984d8f6b10d8a04f5da8a5 Mon Sep 17 00:00:00 2001
From: Manuel Candales <42380156+manuelcandales@users.noreply.github.com>
Date: Tue, 10 Dec 2024 17:39:10 -0500
Subject: [PATCH] Add op: repeat_interleave.Tensor_out

Differential Revision: D67025538

Pull Request resolved: https://github.com/pytorch/executorch/pull/7264
---
 kernels/aten/functions.yaml                   |   2 +
 kernels/portable/cpu/op_repeat_interleave.cpp | 111 ++++++++++++++++++
 kernels/portable/functions.yaml               |   5 +
 kernels/test/op_repeat_interleave_test.cpp    |  43 +++++++
 kernels/test/targets.bzl                      |   1 +
 .../kernels/portable/op_registration_util.bzl |   3 +
 6 files changed, 165 insertions(+)
 create mode 100644 kernels/portable/cpu/op_repeat_interleave.cpp
 create mode 100644 kernels/test/op_repeat_interleave_test.cpp

diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml
index f2f18f51c8..ebcd86d851 100644
--- a/kernels/aten/functions.yaml
+++ b/kernels/aten/functions.yaml
@@ -313,6 +313,8 @@
 
 - op: repeat.out
 
+- op: repeat_interleave.Tensor_out
+
 - op: reflection_pad1d.out
 
 - op: reflection_pad2d.out
diff --git a/kernels/portable/cpu/op_repeat_interleave.cpp b/kernels/portable/cpu/op_repeat_interleave.cpp
new file mode 100644
index 0000000000..c36c8deea9
--- /dev/null
+++ b/kernels/portable/cpu/op_repeat_interleave.cpp
@@ -0,0 +1,111 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace {
+
+bool check_repeat_interleave_args(
+    const Tensor& repeats,
+    int64_t output_size_value,
+    int64_t repeats_sum,
+    Tensor& out) {
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      repeats.scalar_type() == ScalarType::Int ||
+          repeats.scalar_type() == ScalarType::Long,
+      "repeats must be int or long");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(repeats.dim() == 1, "repeats must be 1D");
+  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+      output_size_value == repeats_sum,
+      "output_size, if provided, must be equal to repeats.sum()");
+  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(repeats, out));
+
+  if (repeats.scalar_type() == ScalarType::Long) {
+    const int64_t* const repeats_data = repeats.const_data_ptr<int64_t>();
+    for (size_t i = 0; i < repeats.numel(); ++i) {
+      ET_LOG_MSG_AND_RETURN_IF_FALSE(
+          repeats_data[i] >= 0, "repeats cannot be negative");
+    }
+  } else {
+    const int32_t* const repeats_data = repeats.const_data_ptr<int32_t>();
+    for (size_t i = 0; i < repeats.numel(); ++i) {
+      ET_LOG_MSG_AND_RETURN_IF_FALSE(
+          repeats_data[i] >= 0, "repeats cannot be negative");
+    }
+  }
+
+  return true;
+}
+
+} // namespace
+
+using Tensor = exec_aten::Tensor;
+
+Tensor& repeat_interleave_Tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& repeats,
+    exec_aten::optional<int64_t> output_size,
+    Tensor& out) {
+  (void)ctx;
+
+  int64_t repeats_sum = 0;
+
+  constexpr auto name = "repeat_interleave.Tensor_out";
+
+  ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
+    const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
+    for (size_t ix = 0; ix < repeats.numel(); ++ix) {
+      repeats_sum += static_cast<int64_t>(repeats_data[ix]);
+    }
+  });
+
+  int64_t output_size_value =
+      output_size.has_value() ? output_size.value() : repeats_sum;
+
+  ET_KERNEL_CHECK(
+      ctx,
+      check_repeat_interleave_args(
+          repeats, output_size_value, repeats_sum, out),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(repeats, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensor_is_default_dim_order(repeats), InvalidArgument, out);
+
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(
+          out, {static_cast<exec_aten::SizesType>(output_size_value)}) ==
+          Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
+    const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+    size_t out_ix = 0;
+    for (size_t ix = 0; ix < repeats.numel(); ix++) {
+      for (CTYPE i = 0; i < repeats_data[ix]; i++, out_ix++) {
+        out_data[out_ix] = static_cast<CTYPE>(ix);
+      }
+    }
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/functions.yaml b/kernels/portable/functions.yaml
index 266b5e446f..0da9917214 100644
--- a/kernels/portable/functions.yaml
+++ b/kernels/portable/functions.yaml
@@ -712,6 +712,11 @@
     - arg_meta: null
       kernel_name: torch::executor::repeat_out
 
+- op: repeat_interleave.Tensor_out
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::repeat_interleave_Tensor_out
+
 - op: reflection_pad1d.out
   kernels:
     - arg_meta: null
diff --git a/kernels/test/op_repeat_interleave_test.cpp b/kernels/test/op_repeat_interleave_test.cpp
new file mode 100644
index 0000000000..c4056737be
--- /dev/null
+++ b/kernels/test/op_repeat_interleave_test.cpp
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
+#include <executorch/kernels/test/TestUtil.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
+
+using namespace ::testing;
+using exec_aten::optional;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::testing::TensorFactory;
+
+class OpRepeatInterleaveTensorOutTest : public OperatorTest {
+ protected:
+  Tensor& op_repeat_out(
+      const Tensor& repeats,
+      optional<int64_t> output_size,
+      Tensor& out) {
+    return torch::executor::aten::repeat_interleave_outf(
+        context_, repeats, output_size, out);
+  }
+};
+
+TEST_F(OpRepeatInterleaveTensorOutTest, SmokeTest) {
+  TensorFactory<ScalarType::Long> tf;
+
+  Tensor repeats = tf.make({3}, {2, 3, 1});
+
+  std::vector<int64_t> repeats_vec = {3, 4, 5, 6};
+  Tensor out = tf.zeros({6});
+  Tensor expected = tf.make({6}, {0, 0, 1, 1, 1, 2});
+  Tensor ret = op_repeat_out(repeats, 6, out);
+  EXPECT_TENSOR_EQ(ret, out);
+  EXPECT_TENSOR_EQ(ret, expected);
+}
diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl
index 2e7c34f147..2dd019e1b3 100644
--- a/kernels/test/targets.bzl
+++ b/kernels/test/targets.bzl
@@ -283,6 +283,7 @@ def define_common_targets():
     _common_op_test("op_relu_test", ["aten", "portable"])
     _common_op_test("op_remainder_test", ["aten", "portable"])
     _common_op_test("op_repeat_test", ["aten", "portable"])
+    _common_op_test("op_repeat_interleave_test", ["aten", "portable"])
     _common_op_test("op_reflection_pad1d_test", ["aten", "portable"])
     _common_op_test("op_reflection_pad2d_test", ["aten", "portable"])
     _common_op_test("op_reflection_pad3d_test", ["aten", "portable"])
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index b88e83b058..8a04277c95 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -1001,6 +1001,9 @@ ATEN_OPS = (
             "//executorch/kernels/portable/cpu/util:repeat_util",
         ],
     ),
+    op_target(
+        name = "op_repeat_interleave",
+    ),
     op_target(
         name = "op_replication_pad1d",
         deps = [
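
The kernel's expansion rule: element ix of repeats determines how many times index ix is written to the output, so repeats = [2, 3, 1] yields [0, 0, 1, 1, 1, 2], which is exactly the SmokeTest case above. Below is a minimal standalone C++ sketch of that expansion over a plain std::vector; repeat_interleave_ref is a hypothetical reference helper for illustration only, not part of the patch or of ExecuTorch.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical reference implementation (not the ExecuTorch kernel): write
// index ix to the output repeats[ix] times, mirroring the kernel's nested
// loop over (ix, repeats_data[ix]).
std::vector<int64_t> repeat_interleave_ref(const std::vector<int64_t>& repeats) {
  std::vector<int64_t> out;
  for (std::size_t ix = 0; ix < repeats.size(); ++ix) {
    for (int64_t i = 0; i < repeats[ix]; ++i) {
      out.push_back(static_cast<int64_t>(ix));
    }
  }
  return out;
}

int main() {
  // Same input as the SmokeTest: prints "0 0 1 1 1 2".
  for (int64_t v : repeat_interleave_ref({2, 3, 1})) {
    std::cout << v << ' ';
  }
  std::cout << '\n';
  return 0;
}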