Skip to content

Commit

Permalink
Add op: repeat_interleave.Tensor_out
Browse files Browse the repository at this point in the history
Differential Revision: D67025538

Pull Request resolved: pytorch#7264
  • Loading branch information
manuelcandales authored Dec 10, 2024
1 parent 61bd2b8 commit 343aa0c
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 0 deletions.
2 changes: 2 additions & 0 deletions kernels/aten/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,8 @@

- op: repeat.out

- op: repeat_interleave.Tensor_out

- op: reflection_pad1d.out

- op: reflection_pad2d.out
Expand Down
111 changes: 111 additions & 0 deletions kernels/portable/cpu/op_repeat_interleave.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

/**
 * Validates the arguments to repeat_interleave.Tensor_out.
 *
 * Ensures `repeats` is a 1-D Int or Long tensor with no negative entries,
 * that `out` shares its dtype, and that the provided output size matches
 * the sum of the repeat counts. Returns true iff all checks pass.
 */
bool check_repeat_interleave_args(
    const Tensor& repeats,
    int64_t output_size_value,
    int64_t repeats_sum,
    Tensor& out) {
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      repeats.scalar_type() == ScalarType::Int ||
          repeats.scalar_type() == ScalarType::Long,
      "repeats must be int or long");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(repeats.dim() == 1, "repeats must be 1D");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      output_size_value == repeats_sum,
      "output_size, if provided, must be equal to repeats.sum()");
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(repeats, out));

  // Reject negative repeat counts. The element type decides which data
  // pointer we may legally read; both branches perform the same scan.
  const auto num_repeats = repeats.numel();
  if (repeats.scalar_type() == ScalarType::Int) {
    const int32_t* const repeats_data = repeats.const_data_ptr<int32_t>();
    for (decltype(num_repeats) j = 0; j < num_repeats; ++j) {
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          repeats_data[j] >= 0, "repeats cannot be negative");
    }
  } else {
    const int64_t* const repeats_data = repeats.const_data_ptr<int64_t>();
    for (decltype(num_repeats) j = 0; j < num_repeats; ++j) {
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          repeats_data[j] >= 0, "repeats cannot be negative");
    }
  }

  return true;
}

} // namespace

using Tensor = exec_aten::Tensor;

/**
 * repeat_interleave.Tensor_out: for each index i of `repeats`, writes the
 * value i repeated repeats[i] consecutive times into `out`.
 *
 * Example: repeats = [2, 3, 1]  ->  out = [0, 0, 1, 1, 1, 2].
 *
 * @param ctx Kernel runtime context used for error reporting.
 * @param repeats 1-D Int or Long tensor of non-negative repeat counts.
 * @param output_size Optional; when provided it must equal repeats.sum().
 * @param out Output tensor; resized to {repeats.sum()}. Must have the same
 *     dtype as `repeats`.
 * @returns `out`.
 */
Tensor& repeat_interleave_Tensor_out(
    KernelRuntimeContext& ctx,
    const Tensor& repeats,
    exec_aten::optional<int64_t> output_size,
    Tensor& out) {
  int64_t repeats_sum = 0;

  constexpr auto name = "repeat_interleave.Tensor_out";

  // Sum the repeat counts; this determines the output length. Accumulate in
  // int64_t regardless of the element type to avoid overflow for Int inputs.
  ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
    const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
    const auto num_repeats = repeats.numel();
    for (decltype(num_repeats) ix = 0; ix < num_repeats; ++ix) {
      repeats_sum += static_cast<int64_t>(repeats_data[ix]);
    }
  });

  // If the caller did not supply output_size, derive it from the sum; the
  // args check below enforces equality when it was supplied.
  int64_t output_size_value =
      output_size.has_value() ? output_size.value() : repeats_sum;

  ET_KERNEL_CHECK(
      ctx,
      check_repeat_interleave_args(
          repeats, output_size_value, repeats_sum, out),
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(repeats, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensor_is_default_dim_order(repeats), InvalidArgument, out);

  // Resize out to a 1-D tensor of length repeats.sum().
  ET_KERNEL_CHECK_MSG(
      ctx,
      resize_tensor(
          out, {static_cast<exec_aten::SizesType>(output_size_value)}) ==
          Error::Ok,
      InvalidArgument,
      out,
      "Failed to resize output tensor.");

  // Fill out: index ix appears repeats[ix] consecutive times.
  ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] {
    const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>();
    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
    size_t out_ix = 0;
    const auto num_repeats = repeats.numel();
    for (decltype(num_repeats) ix = 0; ix < num_repeats; ++ix) {
      for (CTYPE i = 0; i < repeats_data[ix]; i++, out_ix++) {
        out_data[out_ix] = static_cast<CTYPE>(ix);
      }
    }
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch
5 changes: 5 additions & 0 deletions kernels/portable/functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,11 @@
- arg_meta: null
kernel_name: torch::executor::repeat_out

- op: repeat_interleave.Tensor_out
kernels:
- arg_meta: null
kernel_name: torch::executor::repeat_interleave_Tensor_out

- op: reflection_pad1d.out
kernels:
- arg_meta: null
Expand Down
43 changes: 43 additions & 0 deletions kernels/test/op_repeat_interleave_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

using namespace ::testing;
using exec_aten::optional;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::testing::TensorFactory;

// Test fixture for repeat_interleave.Tensor_out.
class OpRepeatInterleaveTensorOutTest : public OperatorTest {
 protected:
  // Thin wrapper forwarding to the generated operator entry point.
  // NOTE(review): the name `op_repeat_out` is misleading — it dispatches to
  // repeat_interleave, not repeat; consider renaming to
  // op_repeat_interleave_out along with its call sites.
  Tensor& op_repeat_out(
      const Tensor& repeats,
      optional<int64_t> output_size,
      Tensor& out) {
    return torch::executor::aten::repeat_interleave_outf(
        context_, repeats, output_size, out);
  }
};

// Basic correctness check: each index i appears repeats[i] times in order.
TEST_F(OpRepeatInterleaveTensorOutTest, SmokeTest) {
  TensorFactory<ScalarType::Int> tf;

  // repeats[i] is the number of times index i is emitted.
  Tensor repeats = tf.make({3}, {2, 3, 1});

  Tensor out = tf.zeros({6});
  Tensor expected = tf.make({6}, {0, 0, 1, 1, 1, 2});
  // output_size (6) must equal repeats.sum() = 2 + 3 + 1.
  Tensor ret = op_repeat_out(repeats, 6, out);
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_EQ(ret, expected);
}
1 change: 1 addition & 0 deletions kernels/test/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ def define_common_targets():
_common_op_test("op_relu_test", ["aten", "portable"])
_common_op_test("op_remainder_test", ["aten", "portable"])
_common_op_test("op_repeat_test", ["aten", "portable"])
_common_op_test("op_repeat_interleave_test", ["aten", "portable"])
_common_op_test("op_reflection_pad1d_test", ["aten", "portable"])
_common_op_test("op_reflection_pad2d_test", ["aten", "portable"])
_common_op_test("op_reflection_pad3d_test", ["aten", "portable"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,9 @@ ATEN_OPS = (
"//executorch/kernels/portable/cpu/util:repeat_util",
],
),
op_target(
name = "op_repeat_interleave",
),
op_target(
name = "op_replication_pad1d",
deps = [
Expand Down

0 comments on commit 343aa0c

Please sign in to comment.