Commit

Adding cat operator kernel optimization (#20)
* Adding cat operator kernel optimization

* Adding cat operator kernel optimization

---------

Co-authored-by: dijopaul <[email protected]>
Rushi-cad and dijopaul authored Oct 24, 2024
1 parent 6eff57b commit ae7b6bc
Showing 6 changed files with 344 additions and 9 deletions.
2 changes: 1 addition & 1 deletion backends/cadence/aot/functions_hifi.yaml
@@ -40,7 +40,7 @@
- op: cat.out
kernels:
- arg_meta: null
- kernel_name: torch::executor::cat_out
+ kernel_name: impl::HiFi::cat_out

- op: clamp.Tensor_out
kernels:
1 change: 1 addition & 0 deletions backends/cadence/hifi/kernels/CMakeLists.txt
@@ -10,6 +10,7 @@ add_library(
kernels.cpp
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/matmul_asym8uxasym8u_asym8u.cpp
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_broadcast_32.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_concat_32.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_add_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_atan2_f32.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_clamp_f32_broadcast.c
25 changes: 18 additions & 7 deletions backends/cadence/hifi/kernels/kernels.h
@@ -17,6 +17,16 @@

/* Potential NNLIB function/APIs */

extern "C" WORD32 xa_nn_concat_32_32(
WORD32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD32 **pp_inps,
const WORD32 *const *pp_inps_shape,
WORD32 num_out_dims,
WORD32 num_inp,
WORD32 num_inp_dims,
WORD32 axis);

extern "C" WORD32 xa_nn_broadcast_32_32(
WORD32* __restrict__ p_out,
const int* const out_shape,
@@ -157,13 +167,14 @@ extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(
const unsigned char* __restrict__ p_condition,
const WORD32* const p_condition_shape);

extern "C" WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD32 * __restrict__ p_inp,
const WORD32 *const p_inp_shape,
const WORD32 * __restrict__ p_permute_vec,
WORD32 num_out_dims,
WORD32 num_inp_dims);
extern "C" WORD32 xa_nn_transpose_32_32(
WORD32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD32 * __restrict__ p_inp,
const WORD32 *const p_inp_shape,
const WORD32 * __restrict__ p_permute_vec,
WORD32 num_out_dims,
WORD32 num_inp_dims);

namespace impl {
namespace HiFi {
2 changes: 1 addition & 1 deletion backends/cadence/hifi/operators/CMakeLists.txt
@@ -23,6 +23,7 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_atan2.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bmm.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_cat.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_eq.cpp"
@@ -42,7 +43,6 @@ set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp"
151 changes: 151 additions & 0 deletions backends/cadence/hifi/operators/op_cat.cpp
@@ -0,0 +1,151 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cstring>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using executorch::runtime::getLeadingDims;
using executorch::runtime::getTrailingDims;
using executorch::runtime::resize_tensor;
using torch::executor::check_cat_args;
using torch::executor::Error;
using torch::executor::get_cat_out_target_size;

namespace impl {
namespace HiFi {
namespace native {

Tensor& cat_out(
RuntimeContext& ctx,
exec_aten::ArrayRef<Tensor> tensors,
int64_t dim,
Tensor& out) {
constexpr auto name = "cat.out";
constexpr int kNnlibMaxDim = 16;

bool optimized = true;

// The NNLib fast path handles only float tensors whose rank and input count
// fit the fixed-size staging arrays below (kNnlibMaxDim).
if (out.scalar_type() != ScalarType::Float)
optimized = false;

if (out.dim() > kNnlibMaxDim)
optimized = false;

if (tensors.size() > static_cast<size_t>(kNnlibMaxDim))
optimized = false;

if (optimized) {
WORD32 num_inp = tensors.size();
WORD32 num_inp_dims = out.dim();
WORD32 num_out_dims = num_inp_dims;
WORD32 axis = dim;

WORD32 inp_shape[kNnlibMaxDim][kNnlibMaxDim];
WORD32 p_out_shape[kNnlibMaxDim];

WORD32* ptr_shape[kNnlibMaxDim];
const WORD32* ptr[kNnlibMaxDim];

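// Collect data pointers and shapes for every non-empty input; empty tensors
// are skipped so they do not contribute to the concatenation.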
int k = 0;
for (int i = 0; i < num_inp; i++) {
if (tensors[i].numel() == 0)
continue;
ptr[k] = (const WORD32*)tensors[i].const_data_ptr<float>();
for (int j = 0; j < num_inp_dims; j++) {
inp_shape[k][j] = tensors[i].size(j);
}
ptr_shape[k] = inp_shape[k];
k++;
}

num_inp = k;

for (int i = 0; i < num_out_dims; i++) {
p_out_shape[i] = out.size(i);
}

const WORD32** pp_inps = &ptr[0];

WORD32* p_out = (WORD32*)out.mutable_data_ptr<float>();

const WORD32* const* pp_inps_shape = (const WORD32* const*)&ptr_shape[0];

WORD32 val = xa_nn_concat_32_32(
p_out,
p_out_shape,
pp_inps,
pp_inps_shape,
num_out_dims,
num_inp,
num_inp_dims,
axis);

return out;
}

if (dim < 0) {
dim += out.dim();
}

ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), InvalidArgument, out);

Tensor::SizesType
expected_out_size[executorch::runtime::kTensorDimensionLimit];
size_t expected_out_dim = 0;
get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);

ET_KERNEL_CHECK(
ctx,
resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
InvalidArgument,
out);

// Special handling when all inputs are 1D empty tensors, for ATen consistency:
// in that case, just return a 1D empty tensor without checking dim.
bool all_1d_empty = true;
for (size_t i = 0; i < tensors.size(); ++i) {
if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
all_1d_empty = false;
break;
}
}
if (all_1d_empty) {
return out;
}

const size_t outer = getLeadingDims(out, dim);
const size_t dim_stride = getTrailingDims(out, dim);
const size_t ninputs = tensors.size();

const auto out_type = out.scalar_type();
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
for (size_t i = 0; i < outer; ++i) {
for (size_t j = 0; j < ninputs; ++j) {
const auto in_type = tensors[j].scalar_type();
ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
if (tensors[j].numel() == 0) {
return;
}
size_t inner = tensors[j].size(dim) * dim_stride;
const CTYPE_IN* const in_ptr =
tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

for (size_t k = 0; k < inner; ++k) {
out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
}
out_ptr += inner;
});
}
}
});

return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
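For reference, the fallback loop above copies, for every outer slice, size(dim) * dim_stride contiguous elements from each input in turn, so the inputs end up interleaved along dim. The standalone C++ sketch below (hypothetical shapes and values, not part of this commit) reproduces that layout for inputs of shape {2, 3, 4} and {2, 5, 4} concatenated along dim = 1, where outer = 2 and dim_stride = 4 correspond to what getLeadingDims and getTrailingDims return in the operator.

// Standalone illustration of the fallback copy layout; shapes/values are hypothetical.
#include <cstdio>
#include <vector>

int main() {
  const int dim = 1;
  // Inputs {2, 3, 4} and {2, 5, 4} concatenated along dim = 1 -> output {2, 8, 4}.
  std::vector<int> shape_a = {2, 3, 4}, shape_b = {2, 5, 4};

  std::vector<float> a(24), b(40), out(64);
  for (size_t i = 0; i < a.size(); ++i) a[i] = 100.0f + i;  // arbitrary fill values
  for (size_t i = 0; i < b.size(); ++i) b[i] = 200.0f + i;

  const size_t outer = 2;       // product of sizes before dim (getLeadingDims)
  const size_t dim_stride = 4;  // product of sizes after dim (getTrailingDims)

  const std::vector<float>* inputs[2] = {&a, &b};
  const std::vector<int>* shapes[2] = {&shape_a, &shape_b};

  float* out_ptr = out.data();
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < 2; ++j) {
      // Each input contributes size(dim) * dim_stride contiguous elements per outer slice.
      const size_t inner = (*shapes[j])[dim] * dim_stride;
      const float* in_ptr = inputs[j]->data() + i * inner;
      for (size_t k = 0; k < inner; ++k) out_ptr[k] = in_ptr[k];
      out_ptr += inner;
    }
  }
  // out[0..11] holds a's first slice, out[12..31] b's first slice, out[32..43] a's second slice, ...
  std::printf("out[0]=%g out[12]=%g out[32]=%g\n", out[0], out[12], out[32]);
  return 0;
}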
172 changes: 172 additions & 0 deletions backends/cadence/hifi/third-party/nnlib/xa_nn_concat_32.c
@@ -0,0 +1,172 @@
#include "xa_type_def.h"
#include "xa_nn_common.h"
#include "xa_nnlib_kernels_api.h"
#include "xa_nnlib_common_macros.h"
#include "xa_nnlib_err_chk.h"
#include "xa_nnlib_common.h"

WORD32 xa_nn_concat_32_32(WORD32 * __restrict__ p_out
,const WORD32 *const p_out_shape
,const WORD32 **pp_inps
,const WORD32 *const *pp_inps_shape
,WORD32 num_out_dims
,WORD32 num_inp
,WORD32 num_inp_dims
,WORD32 axis)
{
XA_NNLIB_ARG_CHK_PTR(p_out, -1);
XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
XA_NNLIB_ARG_CHK_PTR(pp_inps, -1);
XA_NNLIB_ARG_CHK_PTR(pp_inps_shape, -1);
/* Pointer alignment checks */
XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
XA_NNLIB_ARG_CHK_ALIGN(pp_inps, sizeof(WORD32 *), -1);
XA_NNLIB_ARG_CHK_ALIGN(pp_inps_shape, sizeof(WORD32 *), -1);
//Validate Arguments
XA_NNLIB_ARG_CHK_COND((num_out_dims <= 0 || num_out_dims > 6), -1);
XA_NNLIB_ARG_CHK_COND((num_inp <= 0 || num_inp > 10), -1);
XA_NNLIB_ARG_CHK_COND((num_inp_dims != num_out_dims), -1);
XA_NNLIB_ARG_CHK_COND((axis < -num_out_dims || axis >= num_out_dims), -1);

int i = 0, j = 0;
for(i = 0; i < num_out_dims; i++)
{
XA_NNLIB_ARG_CHK_COND((p_out_shape[i] <= 0), -1);
}

if(axis < 0)
axis = num_out_dims + axis;

WORD32 concat_size = 0;
for (i = 0; i < num_inp; i++)
{
XA_NNLIB_ARG_CHK_PTR(pp_inps[i], -1);
XA_NNLIB_ARG_CHK_PTR(pp_inps_shape[i], -1);
XA_NNLIB_ARG_CHK_ALIGN(pp_inps_shape[i], sizeof(WORD32), -1);
#pragma loop_count min=1
for(j = 0; j < num_out_dims; j++)
{
XA_NNLIB_ARG_CHK_COND((pp_inps_shape[i][j] != p_out_shape[j] && j != axis), -1);
}

XA_NNLIB_ARG_CHK_COND((pp_inps_shape[i][axis] <= 0), -1);
concat_size += pp_inps_shape[i][axis];
}

XA_NNLIB_ARG_CHK_COND((p_out_shape[axis] != concat_size), -1);

//Calculate outer and inner size for axis
WORD32 outer_size = 1;
#pragma no_simd
for(int i = 0; i < axis; i++)
{
outer_size *= p_out_shape[i];
}

WORD32 base_inner_size = 1;
#pragma no_simd
for(int i = axis + 1; i < num_out_dims; i++)
{
base_inner_size *= p_out_shape[i];
}

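/* For each input, every outer slice contributes copy_size = pp_inps_shape[i][axis] *
   base_inner_size contiguous words; after copying one slice, the output pointer
   advances by the full output row stride (concat_size * base_inner_size), and
   ptmp_out is offset by copy_size so the inputs interleave along the axis. */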
WORD32 *ptmp_out = p_out;
for(int i = 0; i < num_inp; i++)
{
const WORD32 copy_size = pp_inps_shape[i][axis] * base_inner_size;
WORD32 *output_ptr = ptmp_out;
const WORD32* input_ptr = pp_inps[i];

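/* Both branches below copy copy_size words per outer slice, either with a short
   scalar loop or with unaligned 2x32-bit vector loads/stores (AE_LA32X2 /
   AE_SA32X2); the first branch is taken only when copy_size and the output row
   stride are even and both pointers have their low address bit clear. */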
if(((copy_size & 1) == 0) && (((concat_size * base_inner_size) & 1) == 0)
&& (((unsigned)input_ptr & 1) == 0) && (((unsigned)output_ptr & 1) == 0))
{
if(copy_size <= 8)
{
const ae_f32 *pae_inp = (const ae_f32 *)input_ptr;
for(int k = 0; k < outer_size; k++)
{
ae_f32 *pae_out = (ae_f32 *)output_ptr;
#pragma concurrent
#pragma no_simd
for(int ic = 0; ic < copy_size; ic++)
{
*pae_out++ = *pae_inp++;
}
output_ptr += concat_size * base_inner_size;
}
}
else
{
for(int k = 0; k < outer_size; k++)
{
const ae_int32x2 *pae_inp = (const ae_int32x2 *)input_ptr;
ae_int32x2 *pae_out = (ae_int32x2 *)output_ptr;
ae_valign inp_a, out_a;
inp_a = AE_LA64_PP(pae_inp);
out_a = AE_ZALIGN64();
for(int ic = 0; ic < (copy_size >> 1); ic++)
{
ae_int32x2 d0;
AE_LA32X2_IP(d0, inp_a, pae_inp);
AE_SA32X2_IP(d0, out_a, pae_out);
}
AE_SA64POS_FP(out_a, pae_out);
const ae_f32 *puae_inp = (const ae_f32 *)pae_inp;
ae_f32 *puae_out = (ae_f32 *)pae_out;
#pragma concurrent
for(int ic = 0; ic < (copy_size & 1); ic++)
{
puae_out[copy_size - 1] = puae_inp[copy_size - 1];
}
input_ptr += copy_size;
output_ptr += concat_size * base_inner_size;
}
}
}
else
{
if(copy_size <= 6)
{
for(int k = 0; k < outer_size; k++)
{
#pragma concurrent
#pragma no_unroll
for(int ic = 0; ic < copy_size; ic++)
{
output_ptr[ic] = *input_ptr++;
}
output_ptr += concat_size * base_inner_size;
}
}
else
{
for(int k = 0; k < outer_size; k++)
{
const ae_int32x2 *pae_inp = (const ae_int32x2 *)input_ptr;
ae_int32x2 *pae_out = (ae_int32x2 *)output_ptr;
ae_valign inp_a, out_a;
inp_a = AE_LA64_PP(pae_inp);
out_a = AE_ZALIGN64();

#pragma concurrent
for(int ic = 0; ic < copy_size >> 1; ic++)
{
ae_int32x2 d0;
AE_LA32X2_IP(d0, inp_a, pae_inp);
AE_SA32X2_IP(d0, out_a, pae_out);
}
AE_SA64POS_FP(out_a, pae_out);

for(int ic = 0; ic < (copy_size & 1); ic++)
{
output_ptr[copy_size - 1] = input_ptr[copy_size - 1];
}
input_ptr += copy_size;
output_ptr += concat_size * base_inner_size;
}
}
}
ptmp_out += copy_size;
}
return 0;
}
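For reference, a minimal host-side caller sketch for the new kernel (hypothetical, not part of this commit): it assumes the xa_nn_concat_32_32 signature declared in kernels.h above and linking against the Cadence NNLib, and redeclares WORD32 as int32_t purely for illustration. It concatenates a 2x3 and a 2x2 int32 matrix along axis 1 into a 2x5 output.

// Hypothetical caller; the kernel definition comes from the Cadence NNLib at link time.
// WORD32 is redeclared here only for illustration (the real typedef is in xa_type_def.h).
#include <cstdint>
#include <cstdio>

using WORD32 = std::int32_t;

extern "C" WORD32 xa_nn_concat_32_32(
    WORD32* __restrict__ p_out,
    const WORD32* const p_out_shape,
    const WORD32** pp_inps,
    const WORD32* const* pp_inps_shape,
    WORD32 num_out_dims,
    WORD32 num_inp,
    WORD32 num_inp_dims,
    WORD32 axis);

int main() {
  // Concatenate a 2x3 and a 2x2 int32 matrix along axis 1 -> 2x5 output.
  WORD32 a[2][3] = {{1, 2, 3}, {4, 5, 6}};
  WORD32 b[2][2] = {{7, 8}, {9, 10}};
  WORD32 out[2][5] = {};

  const WORD32 a_shape[2] = {2, 3};
  const WORD32 b_shape[2] = {2, 2};
  const WORD32 out_shape[2] = {2, 5};

  const WORD32* inputs[2] = {&a[0][0], &b[0][0]};
  const WORD32* input_shapes[2] = {a_shape, b_shape};

  WORD32 ret = xa_nn_concat_32_32(
      &out[0][0],
      out_shape,
      inputs,
      input_shapes,
      /*num_out_dims=*/2,
      /*num_inp=*/2,
      /*num_inp_dims=*/2,
      /*axis=*/1);
  if (ret != 0) {
    std::printf("concat failed: %d\n", ret);
    return 1;
  }
  // Expected: row 0 = 1 2 3 7 8, row 1 = 4 5 6 9 10.
  for (int r = 0; r < 2; ++r) {
    for (int c = 0; c < 5; ++c) std::printf("%d ", out[r][c]);
    std::printf("\n");
  }
  return 0;
}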
