diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt
index 07da3400b9..989aa9ecf2 100644
--- a/backends/cadence/hifi/kernels/CMakeLists.txt
+++ b/backends/cadence/hifi/kernels/CMakeLists.txt
@@ -11,6 +11,7 @@ add_library(
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_div_broadcast_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_floor_div_broadcast_f32.c
@@ -25,4 +26,4 @@ target_include_directories(
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/nnlib-hifi4/xa_nnlib/algo/ndsp/hifi4/include/
 )
 
-target_link_libraries(cadence_kernels PRIVATE xa_nnlib)
\ No newline at end of file
+target_link_libraries(cadence_kernels PRIVATE xa_nnlib)
diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h
index 8d3354352a..15b576ed78 100644
--- a/backends/cadence/hifi/kernels/kernels.h
+++ b/backends/cadence/hifi/kernels/kernels.h
@@ -63,6 +63,14 @@ extern "C" WORD32 xa_nn_elm_mul_broadcast_4D_f32xf32_f32(FLOAT32 * __restrict__
                                               const WORD32 *const p_inp1_shape,
                                               const FLOAT32 * __restrict__ p_inp2,
                                               const WORD32 *const p_inp2_shape);
+
+extern "C" WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out
+                                        ,const WORD32 *const p_out_shape
+                                        ,const WORD32 * __restrict__ p_inp
+                                        ,const WORD32 *const p_inp_shape
+                                        ,const WORD32 * __restrict__ p_permute_vec
+                                        ,WORD32 num_out_dims
+                                        ,WORD32 num_inp_dims);
 
 namespace impl {
 namespace HiFi {
diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt
index 9fc35f436a..fc63eba88e 100644
--- a/backends/cadence/hifi/operators/CMakeLists.txt
+++ b/backends/cadence/hifi/operators/CMakeLists.txt
@@ -23,6 +23,7 @@ set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_view_copy.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp"
@@ -38,6 +39,8 @@ set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_permute_copy.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_rsqrt.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_mul.cpp"
@@ -46,7 +49,7 @@ set(_aten_ops__srcs
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp"
   "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_softmax.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp"
-  "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp"
+  "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp"
   "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp"
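The `xa_nn_transpose_32_32` declaration added to kernels.h above follows the usual NNLIB permute contract: `p_permute_vec[i]` names the input dimension that becomes output dimension `i`, so `p_out_shape[i]` must equal `p_inp_shape[p_permute_vec[i]]`, and at most 5 dimensions are supported. A minimal host-side sketch of that contract check; every name here is local to the sketch, not part of the patch:

```cpp
// Illustrative only: mirrors the argument checks that xa_nn_transpose_32.c
// (further down in this patch) performs before transposing.
#include <cstdio>

static bool transpose_args_ok(
    const int* out_shape, const int* in_shape, const int* perm, int ndim) {
  if (ndim <= 0 || ndim > 5) return false; // kernel supports 1..5 dims
  for (int i = 0; i < ndim; ++i) {
    if (in_shape[i] <= 0 || out_shape[i] <= 0) return false;
    if (out_shape[i] != in_shape[perm[i]]) return false; // permuted-shape rule
  }
  return true;
}

int main() {
  const int in_shape[3] = {2, 3, 4};
  const int perm[3] = {2, 0, 1};
  const int out_shape[3] = {4, 2, 3};
  std::printf("%s\n", transpose_args_ok(out_shape, in_shape, perm, 3) ? "ok" : "bad");
  return 0;
}
```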
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_empty.cpp" diff --git a/backends/cadence/hifi/operators/op_permute_copy.cpp b/backends/cadence/hifi/operators/op_permute_copy.cpp new file mode 100644 index 0000000000..135fd8e5d9 --- /dev/null +++ b/backends/cadence/hifi/operators/op_permute_copy.cpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include "kernels.h" + +namespace torch { +namespace executor { +namespace native { + +using SizesType = exec_aten::SizesType; +using Tensor = exec_aten::Tensor; +using IntArrayRef = exec_aten::ArrayRef; + +namespace { + +void increment_coordinate_permuted( + const Tensor& tensor, + size_t* const coordinate, + IntArrayRef dims) { + for (int i = dims.size() - 1; i >= 0; i--) { + size_t d = dims[i] >= 0 ? dims[i] : dims[i] + tensor.dim(); + coordinate[d]++; + if (coordinate[d] == tensor.size(d)) { + coordinate[d] = 0; + } else { + return; + } + } +} + +} // namespace + +Tensor& permute_copy_out( + RuntimeContext& ctx, + const Tensor& in, + IntArrayRef dims, + Tensor& out) { + (void)ctx; + + ET_KERNEL_CHECK( + ctx, check_permute_copy_args(in, dims, out), InvalidArgument, out); + + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; + size_t expected_out_dim = 0; + get_permute_copy_out_target_size( + in, dims, expected_out_size, &expected_out_dim); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok, + InvalidArgument, + out); + + const auto in_type = out.scalar_type(); + + if(in_type == ScalarType::Float) + { + WORD32 * p_inp = (WORD32 *)in.const_data_ptr(); + WORD32 * p_out = (WORD32 *)out.mutable_data_ptr(); + + WORD32 num_inp_dims = in.dim(); + WORD32 num_out_dims = num_inp_dims; + + WORD32 p_inp_shape[5]; + WORD32 p_out_shape[5]; + WORD32 p_permute_vec[5]; + + for(int i = 0; i < num_inp_dims; i++) + { + p_inp_shape[i] = in.size(i); + p_out_shape[i] = in.size(dims[i]); + p_permute_vec[i] = dims[i]; + } + + WORD32 val = xa_nn_transpose_32_32(p_out + ,p_out_shape + ,p_inp + ,p_inp_shape + ,p_permute_vec + ,num_out_dims + ,num_inp_dims); + + } + else if(in_type == ScalarType::Char) + { + WORD8 * p_inp = (WORD8 *)in.const_data_ptr(); + WORD8 * p_out = (WORD8 *)out.mutable_data_ptr(); + + WORD32 num_inp_dims = in.dim(); + WORD32 num_out_dims = num_inp_dims; + + WORD32 p_inp_shape[5]; + WORD32 p_out_shape[5]; + WORD32 p_permute_vec[5]; + + for(int i = 0; i < num_inp_dims; i++) + { + p_inp_shape[i] = in.size(i); + p_out_shape[i] = in.size(dims[i]); + p_permute_vec[i] = dims[i]; + } + + p_inp_shape[num_inp_dims] = 4; + p_out_shape[num_inp_dims] = 4; + + + WORD32 val = xa_nn_transpose_8_8(p_out + ,p_out_shape + ,p_inp + ,p_inp_shape + ,p_permute_vec + ,num_out_dims + ,num_inp_dims); + + } + else + { + // in and out must be the same dtype + ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy.out", CTYPE, [&] { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + size_t in_coord[kTensorDimensionLimit] = {0}; + + for (size_t i = 0; i < out.numel(); ++i) { + out_data[i] = in_data[coordinateToIndex(in, in_coord)]; + increment_coordinate_permuted(in, in_coord, dims); + } + }); 
diff --git a/backends/cadence/hifi/operators/op_sigmoid.cpp b/backends/cadence/hifi/operators/op_sigmoid.cpp
new file mode 100644
index 0000000000..e623a7876c
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_sigmoid.cpp
@@ -0,0 +1,68 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cmath>
+
+#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include "kernels.h"
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+
+Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
+  (void)ctx;
+
+  ET_KERNEL_CHECK(
+      ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out);
+
+  // Resize for dynamic shape
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(out, in.sizes()) == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ScalarType in_type = in.scalar_type();
+  ScalarType out_type = out.scalar_type();
+
+  if(in_type == ScalarType::Float)
+  {
+    float* data_in = in.mutable_data_ptr<float>();
+    float* data_out = out.mutable_data_ptr<float>();
+    xa_nn_vec_sigmoid_f32_f32(data_out, data_in, in.numel());
+  }
+  else
+  {
+    ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
+      ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
+        apply_unary_map_fn(
+            [](const CTYPE_IN val_in) {
+              // perform math in double to preserve precision
+              double in_casted = static_cast<double>(val_in);
+              double out_val = 1.0 / (1.0 + exp(-in_casted));
+              return static_cast<CTYPE_OUT>(out_val);
+            },
+            in.const_data_ptr<CTYPE_IN>(),
+            out.mutable_data_ptr<CTYPE_OUT>(),
+            in.numel());
+      });
+    });
+  }
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
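For dtypes other than Float, `sigmoid_out` above falls back to an element-wise map that evaluates the logistic function in double precision before casting to the output type. A standalone reference of that math, using only names invented for this sketch:

```cpp
// Illustrative reference for the non-Float fallback path of sigmoid_out.
#include <cmath>
#include <cstdio>

template <typename IN_T, typename OUT_T>
void sigmoid_ref(const IN_T* in, OUT_T* out, int n) {
  for (int i = 0; i < n; ++i) {
    const double x = static_cast<double>(in[i]); // compute in double, as the lambda does
    out[i] = static_cast<OUT_T>(1.0 / (1.0 + std::exp(-x)));
  }
}

int main() {
  const int in[3] = {-2, 0, 2};
  float out[3];
  sigmoid_ref(in, out, 3);
  std::printf("%.3f %.3f %.3f\n", out[0], out[1], out[2]); // 0.119 0.500 0.881
  return 0;
}
```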
diff --git a/backends/cadence/hifi/operators/op_sub.cpp b/backends/cadence/hifi/operators/op_sub.cpp
new file mode 100644
index 0000000000..a9aff27da3
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_sub.cpp
@@ -0,0 +1,252 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+#include <executorch/kernels/portable/cpu/util/functional_util.h>
+#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+#include <executorch/runtime/platform/assert.h>
+#include "kernels.h"
+
+#define NNLIB_MAX_DIM 4 /* Add fallback if broadcast and dim > 4 */
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace {
+
+template <
+    bool can_cast,
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner;
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
+  static void
+  run(const Tensor& a, const Tensor& b, CTYPE_IN alpha_val, Tensor& out) {
+    apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
+        // NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
+        [alpha_val](const CTYPE_A val_a, const CTYPE_B val_b) {
+          CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+          CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
+          CTYPE_IN value = a_casted - alpha_val * b_casted;
+
+          return static_cast<CTYPE_OUT>(value);
+        },
+        a,
+        b,
+        out);
+  }
+};
+
+template <typename CTYPE_IN>
+struct ReportCanCastBug {
+  static void run(const Tensor&, const Tensor&, CTYPE_IN, Tensor&) {
+    ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
+  }
+};
+
+template <
+    typename CTYPE_A,
+    typename CTYPE_B,
+    typename CTYPE_IN,
+    typename CTYPE_OUT>
+struct SubInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
+    : public ReportCanCastBug<CTYPE_IN> {};
+
+} // namespace
+
+Tensor& sub_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = b.scalar_type();
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+  ScalarType common_type = promoteTypes(a_type, b_type, /*half_to_float*/ true);
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);
+  ET_KERNEL_CHECK(
+      ctx, check_alpha_type(alpha_type, common_type), InvalidArgument, out);
+
+  float alpha_val;
+  utils::extract_scalar(alpha, &alpha_val);
+
+  constexpr auto name = "sub.out";
+
+
+  int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
+  int fall_back = 0;
+
+  if( (a_dim == 0) || (b_dim == 0) )
+  {
+    fall_back = 1;
+  }
+  if( (out_type != ScalarType::Float) || (alpha_val != 1.0))
+  {
+    fall_back = 1;
+  }
+
+
+  if(!fall_back)
+  {
+    /*logic to find broadcast*/
+    const int a_is_broadcasted = !out.sizes().equals(a.sizes());
+    const int b_is_broadcasted = !out.sizes().equals(b.sizes());
+    const int broadcast = (a_is_broadcasted || b_is_broadcasted);
+
+    const float* const a_data = a.const_data_ptr<float>();
+    const float* const b_data = b.const_data_ptr<float>();
+    float* const out_data = out.mutable_data_ptr<float>();
+    if(broadcast == 1)
+    {
+      int out_shape[NNLIB_MAX_DIM];
+      int inp1_shape[NNLIB_MAX_DIM];
+      int inp2_shape[NNLIB_MAX_DIM];
+
+      for(int i = 0; i < NNLIB_MAX_DIM; i++)
+      {
+        out_shape[i] = 1;
+        inp1_shape[i] = 1;
+        inp2_shape[i] = 1;
+      }
+
+      int off_o = NNLIB_MAX_DIM - out_dim;
+      int off_a = NNLIB_MAX_DIM - a_dim;
+      int off_b = NNLIB_MAX_DIM - b_dim;
+      for(int i = 0; i < out_dim; i++)
+        out_shape[i+off_o] = out.size(i);
+      for(int i = 0; i < a_dim; i++)
+        inp1_shape[i+off_a] = a.size(i);
+      for(int i = 0; i < b_dim; i++)
+        inp2_shape[i+off_b] = b.size(i);
+
+      xa_nn_elm_sub_broadcast_4D_f32xf32_f32(out_data, out_shape, a_data, inp1_shape, b_data, inp2_shape);
+    }
+    else
+    {
+      xa_nn_elm_sub_f32xf32_f32(out_data, a_data, b_data, out.numel());
+    }
+
+  }
+  else
+  {
+
+    ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+      ET_SWITCH_REALH_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+        using CTYPE_IN = typename torch::executor::
+            promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
+        ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+        CTYPE_IN alpha_val;
+        utils::extract_scalar(alpha, &alpha_val);
+        ET_SWITCH_REALH_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
+          SubInner<
+              can_cast<CTYPE_IN, CTYPE_OUT>::value,
+              CTYPE_A,
+              CTYPE_B,
+              CTYPE_IN,
+              CTYPE_OUT>::run(a, b, alpha_val, out);
+        });
+      });
+    });
+  }
+
+  return out;
+}
+
+Tensor& sub_scalar_out(
+    RuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    const Scalar& alpha,
+    Tensor& out) {
+  (void)ctx;
+
+  // Resize for dynamic shape
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(out, a.sizes()) == Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out);
+
+  ScalarType a_type = a.scalar_type();
+  ScalarType b_type = utils::get_scalar_dtype(b);
+  ScalarType alpha_type = utils::get_scalar_dtype(alpha);
+  ScalarType common_type =
+      utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
+  ScalarType out_type = out.scalar_type();
+
+  ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);
+  ET_KERNEL_CHECK(ctx, canCast(alpha_type, common_type), InvalidArgument, out);
+
+  if (common_type == ScalarType::Half) {
+    common_type = ScalarType::Float;
+  }
+
+  constexpr auto name = "sub.Scalar_out";
+
+  ET_SWITCH_REALH_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
+    ET_SWITCH_SCALAR_OBJ_REAL_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
+      using CTYPE_IN = typename utils::promote_type_with_scalar_type<
+          CTYPE_A,
+          CTYPE_B,
+          /*half_to_float*/ true>::type;
+      ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
+
+      CTYPE_B b_val;
+      utils::extract_scalar(b, &b_val);
+      CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
+
+      CTYPE_IN alpha_val;
+      utils::extract_scalar(alpha, &alpha_val);
+
+      using CTYPE_OUT = typename std::conditional<
+          std::is_same<CTYPE_A, internal::F2>::value,
+          internal::F2,
+          CTYPE_IN>::type;
+
+      apply_unary_map_fn(
+          [b_casted, alpha_val](const CTYPE_A val_a) {
+            CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
+            CTYPE_IN value = a_casted - alpha_val * b_casted;
+            return static_cast<CTYPE_OUT>(value);
+          },
+          a.const_data_ptr<CTYPE_A>(),
+          out.mutable_data_ptr<CTYPE_OUT>(),
+          out.numel());
+    });
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
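The broadcast fast path in `sub_out` above right-aligns every shape into a fixed 4-D frame padded with 1s before calling the NNLIB broadcast kernel (hence the `NNLIB_MAX_DIM` fallback note). A small sketch of that padding step, with names local to the sketch:

```cpp
// Illustrative only: the right-aligned 4-D shape padding used by the
// xa_nn_elm_sub_broadcast_4D_f32xf32_f32 fast path.
#include <cstdio>

enum { MAX_DIM = 4 };

static void pad_to_4d(const int* shape, int dim, int padded[MAX_DIM]) {
  for (int i = 0; i < MAX_DIM; ++i) padded[i] = 1;
  const int off = MAX_DIM - dim;
  for (int i = 0; i < dim; ++i) padded[i + off] = shape[i];
}

int main() {
  const int a_shape[3] = {8, 1, 16};
  int padded[MAX_DIM];
  pad_to_4d(a_shape, 3, padded);
  std::printf("%d %d %d %d\n", padded[0], padded[1], padded[2], padded[3]); // 1 8 1 16
  return 0;
}
```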
diff --git a/backends/cadence/hifi/operators/op_view_copy.cpp b/backends/cadence/hifi/operators/op_view_copy.cpp
new file mode 100644
index 0000000000..f7174caac1
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_view_copy.cpp
@@ -0,0 +1,58 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <cstdint>
+#include <cstring>
+
+#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+
+using Tensor = exec_aten::Tensor;
+
+// view_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)
+Tensor& view_copy_out(
+    RuntimeContext& ctx,
+    const Tensor& self,
+    exec_aten::ArrayRef<int64_t> size_int64_t,
+    Tensor& out) {
+  (void)ctx;
+
+  Tensor::SizesType expected_output_size[16];
+  ET_KERNEL_CHECK(
+      ctx,
+      get_view_copy_target_size(
+          self, size_int64_t, out.dim(), expected_output_size),
+      InvalidArgument,
+      out);
+
+  // Resize for dynamic shape
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      resize_tensor(
+          out, {expected_output_size, static_cast<size_t>(out.dim())}) ==
+          Error::Ok,
+      InvalidArgument,
+      out,
+      "Failed to resize output tensor.");
+
+  ET_KERNEL_CHECK(
+      ctx, check_view_copy_args(self, size_int64_t, out), InvalidArgument, out);
+
+  if (self.nbytes() > 0) {
+    memcpy(out.mutable_data_ptr(), self.const_data_ptr(), self.nbytes());
+  }
+  return out;
+}
+
+} // namespace native
+} // namespace executor
+} // namespace torch
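`view_copy_out` above only re-labels the shape; because both tensors are contiguous, the flat byte layout is identical and a single `memcpy` of `self.nbytes()` is enough. A tiny illustration of that invariant, independent of the patch:

```cpp
// Illustrative only: a "view" changes the shape, not the flat element order.
#include <cstdio>
#include <cstring>

int main() {
  float src[2][6] = {{0, 1, 2, 3, 4, 5}, {6, 7, 8, 9, 10, 11}}; // shape [2, 6]
  float dst[3][4];                                              // viewed as [3, 4]
  static_assert(sizeof(src) == sizeof(dst), "same number of elements");
  std::memcpy(dst, src, sizeof(src));
  std::printf("%g %g\n", dst[1][0], dst[2][3]); // 4 11: same flat order, new shape
  return 0;
}
```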
diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
new file mode 100644
index 0000000000..cbcdec8811
--- /dev/null
+++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
@@ -0,0 +1,241 @@
+#include "xa_nnlib_common.h"
+#include "stdio.h"
+/*
+ * Currently only supports up to 5D input tensors.
+ * 1/2/3/4 D input tensors will be scaled up to 5D.
+ * For example, 2x3 -> 1x1x1x2x3.
+ */
+
+WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out
+                    ,const WORD32 *const p_out_shape
+                    ,const WORD32 * __restrict__ p_inp
+                    ,const WORD32 *const p_inp_shape
+                    ,const WORD32 * __restrict__ p_permute_vec
+                    ,WORD32 num_out_dims
+                    ,WORD32 num_inp_dims)
+{
+  /* NULL pointer checks */
+  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_permute_vec, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1);
+
+  /* Invalid input checks */
+  XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 5)), -1);
+  XA_NNLIB_ARG_CHK_COND((num_out_dims != num_inp_dims), -1);
+
+  int itr = 0;
+  for(itr=0; itr < num_inp_dims; itr++)
+  {
+    XA_NNLIB_ARG_CHK_COND((p_inp_shape[itr] <= 0), -1);
+  }
+  for(itr=0; itr < num_out_dims; itr++)
+  {
+    XA_NNLIB_ARG_CHK_COND((p_out_shape[itr] <= 0), -1);
+  }
+
+
+  /* Output shape provided must be correct based on input
+   * shape and permute values */
+  for(itr=0; itr < num_out_dims; itr++)
+  {
+    int output_dim = p_out_shape[itr];
+    int expected_dim = p_inp_shape[p_permute_vec[itr]];
+    XA_NNLIB_ARG_CHK_COND((output_dim != expected_dim), -1);
+  }
+
+  /* Pointer alignment checks */
+  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_permute_vec, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1);
+
+  /* Shift all dim with 1 in the outer part */
+  int eff_output_shape[5];
+  int eff_permute_vec[5];
+
+  for(int i = 0; i < num_out_dims; i++)
+  {
+    eff_output_shape[i] = p_out_shape[i];
+    eff_permute_vec[i] = p_permute_vec[i];
+  }
+
+  int one_i=num_out_dims-1, non_one_i=num_out_dims-1;
+  while(one_i > 0 && non_one_i >=0){
+    while(one_i > 0 && eff_output_shape[one_i]!=1){
+      one_i--;
+    }
+    non_one_i = one_i;
+    while(non_one_i >= 0 && eff_output_shape[non_one_i]==1)
+    {
+      non_one_i--;
+    }
+    if(one_i > 0 && non_one_i >=0){
+      int temp;
+      /*swap output_shape*/
+      {
+        temp = eff_output_shape[one_i];
+        eff_output_shape[one_i] = eff_output_shape[non_one_i];
+        eff_output_shape[non_one_i] = temp;
+      }
+      /*swap permute_vec*/
+      {
+        temp = eff_permute_vec[one_i];
+        eff_permute_vec[one_i] = eff_permute_vec[non_one_i];
+        eff_permute_vec[non_one_i] = temp;
+      }
+
+    }
+  }
+
+  /* Promoting lesser dim tensors to 5D tensors.
+   * Also updating the permute_vec and shapes as needed for optimization */
+  int p_5D_inp_shape[5] = {1, 1, 1, 1, 1};
+  int p_5D_out_shape[5] = {1, 1, 1, 1, 1};
+  int p_5D_permute_vec[5] = {0, 1, 2, 3, 4};
+
+  /* Check if any inner inp dimension is same in the output */
+  int last_dim_same = 1, last_n_same_dim = 0;
+  itr = num_inp_dims - 1;
+  while(itr >= 0)
+  {
+    last_n_same_dim = (last_dim_same && (eff_permute_vec[itr] == itr)) ? (last_n_same_dim + 1) : last_n_same_dim;
+    last_dim_same = (eff_permute_vec[itr] == itr) ? last_dim_same & 1 : last_dim_same & 0;
+    itr--;
+  }
+
+  int dims_added = 5 - num_inp_dims;
+  itr = num_inp_dims - 1;
+  int same_count = last_n_same_dim;
+  int count = 4;
+  while(itr >= 0)
+  {
+    p_5D_inp_shape[count] = (same_count > 0) ? p_5D_inp_shape[count]*p_inp_shape[itr] : p_inp_shape[itr];
+    p_5D_out_shape[count] = (same_count > 0) ? p_5D_out_shape[count]*eff_output_shape[itr] : eff_output_shape[itr];
+    same_count--;
+    itr--;
+    count = (same_count > 0) ? count : count - 1;
+  }
+
+  itr = num_inp_dims - 1;
+  same_count = (last_n_same_dim) ? num_inp_dims - (last_n_same_dim - 1) : 0;
+  count = 4;
+  while(itr >= 0)
+  {
+    p_5D_permute_vec[count] = (same_count > 0) ? eff_permute_vec[itr-(last_n_same_dim - 1)] + dims_added + last_n_same_dim - 1 : eff_permute_vec[itr] + dims_added;
+    same_count--;
+    itr--;
+    count--;
+  }
+
+  int out_dim0, out_dim1, out_dim2, out_dim3, out_dim4;
+  int inp_dim1, inp_dim2, inp_dim3, inp_dim4;
+  int inp_stride[5];
+
+  out_dim0 = p_5D_out_shape[0];
+  out_dim1 = p_5D_out_shape[1];
+  out_dim2 = p_5D_out_shape[2];
+  out_dim3 = p_5D_out_shape[3];
+  out_dim4 = p_5D_out_shape[4];
+
+  inp_dim1 = p_5D_inp_shape[1];
+  inp_dim2 = p_5D_inp_shape[2];
+  inp_dim3 = p_5D_inp_shape[3];
+  inp_dim4 = p_5D_inp_shape[4];
+
+  inp_stride[0] = inp_dim1*inp_dim2*inp_dim3*inp_dim4;
+  inp_stride[1] = inp_dim2*inp_dim3*inp_dim4;
+  inp_stride[2] = inp_dim3*inp_dim4;
+  inp_stride[3] = inp_dim4;
+  inp_stride[4] = 1;
+
+  if(last_n_same_dim)
+  {
+    int itr0, itr1, itr2, itr3, itr4;
+    WORD32 *p_inp0 = (WORD32 *)p_inp;
+    for(itr0 = 0; itr0 < out_dim0; itr0++)
+    {
+      WORD32 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+#pragma loop_count min=1
+      for(itr1 = 0; itr1 < out_dim1; itr1++)
+      {
+        WORD32 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+#pragma loop_count min=1
+        for(itr2 = 0; itr2 < out_dim2; itr2++)
+        {
+          WORD32 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+#pragma loop_count min=1
+          for(itr3 = 0; itr3 < out_dim3; itr3++, p_out+=out_dim4)
+          {
+            WORD32 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+            ae_int32x2 *__restrict__ pae_i = (ae_int32x2 *)(p_inp4);
+            ae_int32x2 *__restrict__ pae_o = (ae_int32x2 *)(p_out);
+            ae_valign a_inp = AE_LA64_PP(pae_i);
+            ae_valign a_out = AE_ZALIGN64();
+            ae_int32x2 d0;
+            for(itr4 = 0; itr4 < (out_dim4 >> 1); itr4++)
+            {
+              AE_LA32X2_IP(d0, a_inp, pae_i);
+              AE_SA32X2_IP(d0, a_out, pae_o);
+            }
+            AE_SA64POS_FP(a_out, pae_o);
+            ae_int32 *__restrict__ puae_i = (ae_int32 *)(pae_i);
+            ae_int32 *__restrict__ puae_o = (ae_int32 *)(pae_o);
+#pragma loop_count max=3
+            for(itr4 = 0; itr4 < (out_dim4 & 1); itr4++)
+            {
+              puae_o[itr4] = puae_i[itr4];
+            }
+          }
+        }
+      }
+    }
+  }
+  else
+  {
+    int itr0, itr1, itr2, itr3, itr4;
+    WORD32 *p_inp0 = (WORD32 *)p_inp;
+    for(itr0 = 0; itr0 < out_dim0; itr0++)
+    {
+      WORD32 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+      for(itr1 = 0; itr1 < out_dim1; itr1++)
+      {
+        WORD32 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+        for(itr2 = 0; itr2 < out_dim2; itr2++)
+        {
+          WORD32 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+          for(itr3 = 0; itr3 < out_dim3; itr3++)
+          {
+            WORD32 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+
+            ae_valign a_out = AE_ZALIGN64();
+            for(itr4 = 0; itr4 < (out_dim4 >> 1); itr4++)
+            {
+              ae_int32x2 d0, d1;
+              ae_int32x2 tmp0;
+
+              d0 = AE_L32_X((ae_int32 *)p_inp4, 0);
+              p_inp4 += inp_stride[p_5D_permute_vec[4]];
+              d1 = AE_L32_X((ae_int32 *)p_inp4, 0);
+              p_inp4 += inp_stride[p_5D_permute_vec[4]];
+
+              tmp0 = AE_SEL32_HH(d0, d1);
+
+              AE_SA32X2_IP(tmp0, a_out, (ae_int32x2 *)p_out);
+            }
+            AE_SA64POS_FP(a_out, p_out);
+#pragma loop_count max=3
+            for(itr4 = 0; itr4 < (out_dim4 & 1); itr4++)
+            {
+              *p_out++ = *p_inp4;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return 0;
+}
\ No newline at end of file
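`xa_nn_transpose_32_32` above gets most of its speed from the `last_n_same_dim` folding: when the trailing dimensions are not permuted, they are merged into one contiguous inner row that can be moved with aligned 2x32-bit copies instead of strided single-element loads. A sketch of that folding, assuming a 4-D shape and using only names invented here:

```cpp
// Illustrative only: folding trailing, un-permuted dimensions into one
// contiguous inner row, as the last_n_same_dim fast path does.
#include <cstdio>

int main() {
  const int in_shape[4] = {2, 3, 4, 5};
  const int perm[4] = {1, 0, 2, 3}; // last two dims stay in place

  int last_n_same = 0;
  for (int i = 3; i >= 0 && perm[i] == i; --i) last_n_same++;

  int inner = 1;
  for (int i = 4 - last_n_same; i < 4; ++i) inner *= in_shape[i];

  // The transpose then degenerates to permuting a 2x3 outer block of
  // contiguous rows, each `inner` elements long.
  std::printf("trailing dims kept: %d, inner row length: %d\n", last_n_same, inner); // 2, 20
  return 0;
}
```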