From e2d3785f28def2f6e28c470d511883fb44077200 Mon Sep 17 00:00:00 2001 From: Rushi-cad Date: Mon, 21 Oct 2024 02:43:22 -0700 Subject: [PATCH] Adding permute_copy operator kernel optimization --- .../hifi/operators/op_permute_copy.cpp | 132 +++++++++--------- 1 file changed, 63 insertions(+), 69 deletions(-) diff --git a/backends/cadence/hifi/operators/op_permute_copy.cpp b/backends/cadence/hifi/operators/op_permute_copy.cpp index d5f30b2b3d..8939979597 100644 --- a/backends/cadence/hifi/operators/op_permute_copy.cpp +++ b/backends/cadence/hifi/operators/op_permute_copy.cpp @@ -6,17 +6,17 @@ * LICENSE file in the root directory of this source tree. */ +#include #include #include -#include -using exec_aten::SizesType; using exec_aten::ScalarType; +using exec_aten::SizesType; using exec_aten::Tensor; using executorch::runtime::IntArrayRef; -using torch::executor::Error; using executorch::runtime::KernelRuntimeContext; using executorch::runtime::kTensorDimensionLimit; +using torch::executor::Error; namespace impl { namespace HiFi { @@ -65,109 +65,103 @@ Tensor& permute_copy_out( out); const auto in_type = out.scalar_type(); - + constexpr auto name = "permute_copy.out"; constexpr int kNnlibMaxDim = 16; - + bool optimized = 0; - - if(out.scalar_type() == ScalarType::Float) + + if (out.scalar_type() == ScalarType::Float) optimized = 1; - else if(out.scalar_type() == ScalarType::Char) + else if (out.scalar_type() == ScalarType::Char) optimized = 1; - else if(out.scalar_type() == ScalarType::Byte) + else if (out.scalar_type() == ScalarType::Byte) optimized = 1; - if(in.dim() > kNnlibMaxDim) + if (in.dim() > kNnlibMaxDim) optimized = 0; - - if(optimized){ - - if(in_type == ScalarType::Float) - { - WORD32 * p_inp = (WORD32 *)in.const_data_ptr(); - WORD32 * p_out = (WORD32 *)out.mutable_data_ptr(); - + + if (optimized) { + if (in_type == ScalarType::Float) { + WORD32* p_inp = (WORD32*)in.const_data_ptr(); + WORD32* p_out = (WORD32*)out.mutable_data_ptr(); + WORD32 num_inp_dims = in.dim(); WORD32 num_out_dims = num_inp_dims; - + WORD32 p_inp_shape[kNnlibMaxDim]; WORD32 p_out_shape[kNnlibMaxDim]; WORD32 p_permute_vec[kNnlibMaxDim]; - - for(int i = 0; i < num_inp_dims; i++) - { + + for (int i = 0; i < num_inp_dims; i++) { p_inp_shape[i] = in.size(i); p_out_shape[i] = in.size(dims[i]); p_permute_vec[i] = dims[i]; } - - xa_nn_transpose_32_32(p_out, - p_out_shape, - p_inp, - p_inp_shape, - p_permute_vec, - num_out_dims, - num_inp_dims); - + + xa_nn_transpose_32_32( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + return out; - } - else if(in_type == ScalarType::Char) - { - WORD8 * p_inp = (WORD8 *)in.const_data_ptr(); - WORD8 * p_out = (WORD8 *)out.mutable_data_ptr(); - + } else if (in_type == ScalarType::Char) { + WORD8* p_inp = (WORD8*)in.const_data_ptr(); + WORD8* p_out = (WORD8*)out.mutable_data_ptr(); + WORD32 num_inp_dims = in.dim(); WORD32 num_out_dims = num_inp_dims; - + WORD32 p_inp_shape[kNnlibMaxDim]; WORD32 p_out_shape[kNnlibMaxDim]; WORD32 p_permute_vec[kNnlibMaxDim]; - - for(int i = 0; i < num_inp_dims; i++) - { + + for (int i = 0; i < num_inp_dims; i++) { p_inp_shape[i] = in.size(i); p_out_shape[i] = in.size(dims[i]); p_permute_vec[i] = dims[i]; } - - xa_nn_transpose_8_8(p_out, - p_out_shape, - p_inp, - p_inp_shape, - p_permute_vec, - num_out_dims, - num_inp_dims); - + + xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + return out; - } - else if(in_type == ScalarType::Byte) - { - WORD8 * p_inp = (WORD8 *)in.const_data_ptr(); - WORD8 * p_out = (WORD8 *)out.mutable_data_ptr(); - + } else if (in_type == ScalarType::Byte) { + WORD8* p_inp = (WORD8*)in.const_data_ptr(); + WORD8* p_out = (WORD8*)out.mutable_data_ptr(); + WORD32 num_inp_dims = in.dim(); WORD32 num_out_dims = num_inp_dims; - + WORD32 p_inp_shape[kNnlibMaxDim]; WORD32 p_out_shape[kNnlibMaxDim]; WORD32 p_permute_vec[kNnlibMaxDim]; - - for(int i = 0; i < num_inp_dims; i++) - { + + for (int i = 0; i < num_inp_dims; i++) { p_inp_shape[i] = in.size(i); p_out_shape[i] = in.size(dims[i]); p_permute_vec[i] = dims[i]; } - - xa_nn_transpose_8_8(p_out, - p_out_shape, - p_inp, - p_inp_shape, - p_permute_vec, - num_out_dims, - num_inp_dims); - + + xa_nn_transpose_8_8( + p_out, + p_out_shape, + p_inp, + p_inp_shape, + p_permute_vec, + num_out_dims, + num_inp_dims); + return out; } }