Skip to content

Commit

Permalink
Adding permute_copy operator kernel optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
Rushi-cad committed Oct 21, 2024
1 parent 9f169fb commit e2d3785
Showing 1 changed file with 63 additions and 69 deletions.
132 changes: 63 additions & 69 deletions backends/cadence/hifi/operators/op_permute_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::SizesType;
using exec_aten::ScalarType;
using exec_aten::SizesType;
using exec_aten::Tensor;
using executorch::runtime::IntArrayRef;
using torch::executor::Error;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::kTensorDimensionLimit;
using torch::executor::Error;

namespace impl {
namespace HiFi {
Expand Down Expand Up @@ -65,109 +65,103 @@ Tensor& permute_copy_out(
out);

const auto in_type = out.scalar_type();

constexpr auto name = "permute_copy.out";
constexpr int kNnlibMaxDim = 16;

bool optimized = 0;
if(out.scalar_type() == ScalarType::Float)

if (out.scalar_type() == ScalarType::Float)
optimized = 1;
else if(out.scalar_type() == ScalarType::Char)
else if (out.scalar_type() == ScalarType::Char)
optimized = 1;
else if(out.scalar_type() == ScalarType::Byte)
else if (out.scalar_type() == ScalarType::Byte)
optimized = 1;

if(in.dim() > kNnlibMaxDim)
if (in.dim() > kNnlibMaxDim)
optimized = 0;

if(optimized){

if(in_type == ScalarType::Float)
{
WORD32 * p_inp = (WORD32 *)in.const_data_ptr<float>();
WORD32 * p_out = (WORD32 *)out.mutable_data_ptr<float>();


if (optimized) {
if (in_type == ScalarType::Float) {
WORD32* p_inp = (WORD32*)in.const_data_ptr<float>();
WORD32* p_out = (WORD32*)out.mutable_data_ptr<float>();

WORD32 num_inp_dims = in.dim();
WORD32 num_out_dims = num_inp_dims;

WORD32 p_inp_shape[kNnlibMaxDim];
WORD32 p_out_shape[kNnlibMaxDim];
WORD32 p_permute_vec[kNnlibMaxDim];

for(int i = 0; i < num_inp_dims; i++)
{

for (int i = 0; i < num_inp_dims; i++) {
p_inp_shape[i] = in.size(i);
p_out_shape[i] = in.size(dims[i]);
p_permute_vec[i] = dims[i];
}

xa_nn_transpose_32_32(p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);


xa_nn_transpose_32_32(
p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);

return out;
}
else if(in_type == ScalarType::Char)
{
WORD8 * p_inp = (WORD8 *)in.const_data_ptr<char>();
WORD8 * p_out = (WORD8 *)out.mutable_data_ptr<char>();

} else if (in_type == ScalarType::Char) {
WORD8* p_inp = (WORD8*)in.const_data_ptr<char>();
WORD8* p_out = (WORD8*)out.mutable_data_ptr<char>();

WORD32 num_inp_dims = in.dim();
WORD32 num_out_dims = num_inp_dims;

WORD32 p_inp_shape[kNnlibMaxDim];
WORD32 p_out_shape[kNnlibMaxDim];
WORD32 p_permute_vec[kNnlibMaxDim];

for(int i = 0; i < num_inp_dims; i++)
{

for (int i = 0; i < num_inp_dims; i++) {
p_inp_shape[i] = in.size(i);
p_out_shape[i] = in.size(dims[i]);
p_permute_vec[i] = dims[i];
}

xa_nn_transpose_8_8(p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);


xa_nn_transpose_8_8(
p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);

return out;
}
else if(in_type == ScalarType::Byte)
{
WORD8 * p_inp = (WORD8 *)in.const_data_ptr<uint8_t>();
WORD8 * p_out = (WORD8 *)out.mutable_data_ptr<uint8_t>();

} else if (in_type == ScalarType::Byte) {
WORD8* p_inp = (WORD8*)in.const_data_ptr<uint8_t>();
WORD8* p_out = (WORD8*)out.mutable_data_ptr<uint8_t>();

WORD32 num_inp_dims = in.dim();
WORD32 num_out_dims = num_inp_dims;

WORD32 p_inp_shape[kNnlibMaxDim];
WORD32 p_out_shape[kNnlibMaxDim];
WORD32 p_permute_vec[kNnlibMaxDim];

for(int i = 0; i < num_inp_dims; i++)
{

for (int i = 0; i < num_inp_dims; i++) {
p_inp_shape[i] = in.size(i);
p_out_shape[i] = in.size(dims[i]);
p_permute_vec[i] = dims[i];
}

xa_nn_transpose_8_8(p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);


xa_nn_transpose_8_8(
p_out,
p_out_shape,
p_inp,
p_inp_shape,
p_permute_vec,
num_out_dims,
num_inp_dims);

return out;
}
}
Expand Down

0 comments on commit e2d3785

Please sign in to comment.