forked from pytorch/executorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add op: repeat_interleave.Tensor_out
Differential Revision: D67025538 Pull Request resolved: pytorch#7264
- Loading branch information
1 parent
61bd2b8
commit 343aa0c
Showing
6 changed files
with
165 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/runtime/kernel/kernel_includes.h> | ||
|
||
namespace torch { | ||
namespace executor { | ||
namespace native { | ||
namespace { | ||
|
||
bool check_repeat_interleave_args( | ||
const Tensor& repeats, | ||
int64_t output_size_value, | ||
int64_t repeats_sum, | ||
Tensor& out) { | ||
ET_LOG_MSG_AND_RETURN_IF_FALSE( | ||
repeats.scalar_type() == ScalarType::Int || | ||
repeats.scalar_type() == ScalarType::Long, | ||
"repeats must be int or long"); | ||
ET_LOG_MSG_AND_RETURN_IF_FALSE(repeats.dim() == 1, "repeats must be 1D"); | ||
ET_LOG_MSG_AND_RETURN_IF_FALSE( | ||
output_size_value == repeats_sum, | ||
"output_size, if provided, must be equal to repeats.sum()"); | ||
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(repeats, out)); | ||
|
||
if (repeats.scalar_type() == ScalarType::Long) { | ||
const int64_t* const repeats_data = repeats.const_data_ptr<int64_t>(); | ||
for (size_t i = 0; i < repeats.numel(); ++i) { | ||
ET_LOG_MSG_AND_RETURN_IF_FALSE( | ||
repeats_data[i] >= 0, "repeats cannot be negative"); | ||
} | ||
} else { | ||
const int32_t* const repeats_data = repeats.const_data_ptr<int32_t>(); | ||
for (size_t i = 0; i < repeats.numel(); ++i) { | ||
ET_LOG_MSG_AND_RETURN_IF_FALSE( | ||
repeats_data[i] >= 0, "repeats cannot be negative"); | ||
} | ||
} | ||
|
||
return true; | ||
} | ||
|
||
} // namespace | ||
|
||
using Tensor = exec_aten::Tensor; | ||
|
||
Tensor& repeat_interleave_Tensor_out( | ||
KernelRuntimeContext& ctx, | ||
const Tensor& repeats, | ||
exec_aten::optional<int64_t> output_size, | ||
Tensor& out) { | ||
(void)ctx; | ||
|
||
int64_t repeats_sum = 0; | ||
|
||
constexpr auto name = "repeat_interleave.Tensor_out"; | ||
|
||
ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] { | ||
const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>(); | ||
for (size_t ix = 0; ix < repeats.numel(); ++ix) { | ||
repeats_sum += static_cast<int64_t>(repeats_data[ix]); | ||
} | ||
}); | ||
|
||
int64_t output_size_value = | ||
output_size.has_value() ? output_size.value() : repeats_sum; | ||
|
||
ET_KERNEL_CHECK( | ||
ctx, | ||
check_repeat_interleave_args( | ||
repeats, output_size_value, repeats_sum, out), | ||
InvalidArgument, | ||
out); | ||
|
||
ET_KERNEL_CHECK( | ||
ctx, tensors_have_same_dim_order(repeats, out), InvalidArgument, out); | ||
|
||
ET_KERNEL_CHECK( | ||
ctx, tensor_is_default_dim_order(repeats), InvalidArgument, out); | ||
|
||
ET_KERNEL_CHECK_MSG( | ||
ctx, | ||
resize_tensor( | ||
out, {static_cast<exec_aten::SizesType>(output_size_value)}) == | ||
Error::Ok, | ||
InvalidArgument, | ||
out, | ||
"Failed to resize output tensor."); | ||
|
||
ET_SWITCH_TWO_TYPES(Int, Long, repeats.scalar_type(), ctx, name, CTYPE, [&] { | ||
const CTYPE* repeats_data = repeats.const_data_ptr<CTYPE>(); | ||
CTYPE* out_data = out.mutable_data_ptr<CTYPE>(); | ||
size_t out_ix = 0; | ||
for (size_t ix = 0; ix < repeats.numel(); ix++) { | ||
for (CTYPE i = 0; i < repeats_data[ix]; i++, out_ix++) { | ||
out_data[out_ix] = static_cast<CTYPE>(ix); | ||
} | ||
} | ||
}); | ||
|
||
return out; | ||
} | ||
|
||
} // namespace native | ||
} // namespace executor | ||
} // namespace torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
/* | ||
* Copyright (c) Meta Platforms, Inc. and affiliates. | ||
* All rights reserved. | ||
* | ||
* This source code is licensed under the BSD-style license found in the | ||
* LICENSE file in the root directory of this source tree. | ||
*/ | ||
|
||
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator | ||
#include <executorch/kernels/test/TestUtil.h> | ||
#include <executorch/runtime/core/exec_aten/exec_aten.h> | ||
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h> | ||
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h> | ||
|
||
using namespace ::testing; | ||
using exec_aten::optional; | ||
using exec_aten::ScalarType; | ||
using exec_aten::Tensor; | ||
using torch::executor::testing::TensorFactory; | ||
|
||
// Test fixture for the repeat_interleave.Tensor_out operator.
class OpRepeatInterleaveTensorOutTest : public OperatorTest {
 protected:
  // Thin wrapper around the generated out-variant entry point, supplying
  // the fixture's KernelRuntimeContext (`context_` from OperatorTest).
  Tensor& op_repeat_out(
      const Tensor& repeats,
      optional<int64_t> output_size,
      Tensor& out) {
    return torch::executor::aten::repeat_interleave_outf(
        context_, repeats, output_size, out);
  }
};
|
||
// Basic end-to-end check of repeat_interleave.Tensor_out.
// repeats = [2, 3, 1] means index 0 appears twice, index 1 three times,
// and index 2 once, so the expected output is [0, 0, 1, 1, 1, 2] with
// output_size = 2 + 3 + 1 = 6.
// Fix: removed the unused local `repeats_vec` ({3, 4, 5, 6}), which was
// never referenced and whose values were unrelated to this test.
TEST_F(OpRepeatInterleaveTensorOutTest, SmokeTest) {
  TensorFactory<ScalarType::Int> tf;

  Tensor repeats = tf.make({3}, {2, 3, 1});

  Tensor out = tf.zeros({6});
  Tensor expected = tf.make({6}, {0, 0, 1, 1, 1, 2});
  Tensor ret = op_repeat_out(repeats, 6, out);
  // The kernel returns its `out` argument, and the contents must match
  // the expected expansion.
  EXPECT_TENSOR_EQ(ret, out);
  EXPECT_TENSOR_EQ(ret, expected);
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters