From b2de35c1d178cb8029e24d3ccbd64e1c3bd95cab Mon Sep 17 00:00:00 2001 From: nishpoonia <94543206+nishpoonia@users.noreply.github.com> Date: Wed, 23 Oct 2024 13:46:01 +0530 Subject: [PATCH] Adding full operator (#27) Co-authored-by: dijopaul <87994875+dijopaul@users.noreply.github.com> --- backends/cadence/aot/functions_hifi.yaml | 2 +- .../cadence/hifi/operators/CMakeLists.txt | 1 + backends/cadence/hifi/operators/op_full.cpp | 91 +++++++++++++++++++ 3 files changed, 93 insertions(+), 1 deletion(-) create mode 100644 backends/cadence/hifi/operators/op_full.cpp diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml index 31d61d5d94..47006430a3 100644 --- a/backends/cadence/aot/functions_hifi.yaml +++ b/backends/cadence/aot/functions_hifi.yaml @@ -65,7 +65,7 @@ - op: full.out kernels: - arg_meta: null - kernel_name: torch::executor::full_out + kernel_name: impl::HiFi::full_out - op: maximum.out kernels: diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt index 77fcbbbaa5..da29ae4a0b 100644 --- a/backends/cadence/hifi/operators/CMakeLists.txt +++ b/backends/cadence/hifi/operators/CMakeLists.txt @@ -24,6 +24,7 @@ set(_aten_ops__srcs "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bmm.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp" + "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_full.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp" "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp" diff --git a/backends/cadence/hifi/operators/op_full.cpp b/backends/cadence/hifi/operators/op_full.cpp new file mode 100644 index 0000000000..212fe3b306 --- /dev/null +++ b/backends/cadence/hifi/operators/op_full.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace impl { +namespace HiFi { +namespace native { + +using exec_aten::IntArrayRef; +using exec_aten::RuntimeContext; +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::native::utils::extract_scalar; +using torch::executor::native::utils::get_scalar_dtype; + +Tensor& full_out( + RuntimeContext& ctx, + const IntArrayRef sizes, + const Scalar& fill_value, + Tensor& out) { + (void)ctx; + + ScalarType val_type = get_scalar_dtype(fill_value); + ScalarType out_type = out.scalar_type(); + + // Resize for dynamic shape + ET_KERNEL_CHECK_MSG( + ctx, + resize_tensor(out, sizes) == Error::Ok, + InvalidArgument, + out, + "Failed to resize output tensor."); + + constexpr auto name = "full.out"; + + bool optimized = 0; + if (out_type == ScalarType::Long || out_type == ScalarType::Float || + out_type == ScalarType::Byte || out_type == ScalarType::Char) + optimized = 1; + + if (optimized) { + if (out_type == ScalarType::Long) { + int* data_out = out.mutable_data_ptr(); + int64_t val = fill_value.to(); + int val_casted = static_cast(val); + for (size_t i = 0; i < out.numel(); ++i) { + data_out[i] = val_casted; + } + } else if (out_type == ScalarType::Float) { + float* data_out = out.mutable_data_ptr(); + double val = fill_value.to(); + float val_casted = static_cast(val); + xa_nn_memset_f32_f32(data_out, val_casted, out.numel()); + } else if (out_type == ScalarType::Byte || out_type == ScalarType::Char) { + char* data_out = out.mutable_data_ptr(); + int val = fill_value.to(); + memset((void*)data_out, val, out.numel()); + } + + return out; + } + + ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] { + CTYPE_VAL val; + extract_scalar(fill_value, &val); + + ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "full", CTYPE_OUT, [&] { + CTYPE_OUT val_casted = static_cast(val); + auto data_out = out.mutable_data_ptr(); + for (size_t i = 0; i < out.numel(); ++i) { + data_out[i] = val_casted; + } + }); + }); + return out; +} + +} // namespace native +} // namespace HiFi +} // namespace impl