Adding cat operator kernel optimization (#20)
* Adding cat operator kernel optimization

Co-authored-by: dijopaul <[email protected]>
Showing 6 changed files with 344 additions and 9 deletions.
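At a high level, this change routes float cat on the HiFi backend through a new NNLib kernel, xa_nn_concat_32_32, while keeping the portable per-dtype copy as a fallback. Below is a minimal sketch of how that kernel is driven, using only the signature added in this commit; the shapes and data values are illustrative, and it is assumed (not confirmed by the diff) that the kernels.h header the operator includes also declares WORD32 and the kernel.

// Sketch only: concatenate two 2-D inputs along axis 1, the same way the
// optimized cat_out path below does for float tensors reinterpreted as
// 32-bit words.
#include <executorch/backends/cadence/hifi/kernels/kernels.h>  // assumed to declare WORD32 and xa_nn_concat_32_32

int concat_example(void) {
  WORD32 a[2 * 2] = {1, 2, 3, 4};          // shape {2, 2}
  WORD32 b[2 * 3] = {5, 6, 7, 8, 9, 10};   // shape {2, 3}
  WORD32 out[2 * 5];                       // shape {2, 5}

  const WORD32 a_shape[2] = {2, 2};
  const WORD32 b_shape[2] = {2, 3};
  const WORD32 out_shape[2] = {2, 5};

  const WORD32* inputs[2] = {a, b};
  const WORD32* const shapes[2] = {a_shape, b_shape};

  // Expected result: row-wise concatenation,
  // out = {1, 2, 5, 6, 7,  3, 4, 8, 9, 10}; returns 0 on success.
  return xa_nn_concat_32_32(
      out, out_shape, inputs, shapes,
      /*num_out_dims=*/2, /*num_inp=*/2, /*num_inp_dims=*/2, /*axis=*/1);
}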
@@ -0,0 +1,151 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cstring>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using torch::executor::Error;

namespace impl {
namespace HiFi {
namespace native {

Tensor& cat_out(
    RuntimeContext& ctx,
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  constexpr auto name = "cat.out";
  constexpr int kNnlibMaxDim = 16;
  bool optimized = true;

  if (out.scalar_type() != ScalarType::Float)
    optimized = false;

  if (out.dim() > kNnlibMaxDim)
    optimized = false;
  if (optimized) {
    WORD32 num_inp = tensors.size();
    WORD32 num_inp_dims = out.dim();
    WORD32 num_out_dims = num_inp_dims;
    WORD32 axis = dim;

    WORD32 inp_shape[kNnlibMaxDim][kNnlibMaxDim];
    WORD32 p_out_shape[kNnlibMaxDim];

    WORD32* ptr_shape[kNnlibMaxDim];
    const WORD32* ptr[kNnlibMaxDim];

    int k = 0;
    for (int i = 0; i < num_inp; i++) {
      if (tensors[i].numel() == 0)
        continue;
      ptr[k] = (const WORD32*)tensors[i].const_data_ptr<float>();
      for (int j = 0; j < num_inp_dims; j++) {
        inp_shape[k][j] = tensors[i].size(j);
      }
      ptr_shape[k] = inp_shape[k];
      k++;
    }

    num_inp = k;

    for (int i = 0; i < num_out_dims; i++) {
      p_out_shape[i] = out.size(i);
    }

    const WORD32** pp_inps = &ptr[0];

    WORD32* p_out = (WORD32*)out.mutable_data_ptr<float>();

    const WORD32* const* pp_inps_shape = (const WORD32* const*)&ptr_shape[0];

    WORD32 val = xa_nn_concat_32_32(
        p_out,
        p_out_shape,
        pp_inps,
        pp_inps_shape,
        num_out_dims,
        num_inp,
        num_inp_dims,
        axis);

    return out;
  }

  if (dim < 0) {
    dim += out.dim();
  }

  ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), InvalidArgument, out);

  Tensor::SizesType
      expected_out_size[executorch::runtime::kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  // Special handling when all inputs are 1D-empty tensors for aten consistency
  // In that case, just return a 1D-empty tensor without checking dim
  bool all_1d_empty = true;
  for (size_t i = 0; i < tensors.size(); ++i) {
    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
      all_1d_empty = false;
      break;
    }
  }
  if (all_1d_empty) {
    return out;
  }

  const size_t outer = getLeadingDims(out, dim);
  const size_t dim_stride = getTrailingDims(out, dim);
  const size_t ninputs = tensors.size();

  const auto out_type = out.scalar_type();
  ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
    for (size_t i = 0; i < outer; ++i) {
      for (size_t j = 0; j < ninputs; ++j) {
        const auto in_type = tensors[j].scalar_type();
        ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
          if (tensors[j].numel() == 0) {
            return;
          }
          size_t inner = tensors[j].size(dim) * dim_stride;
          const CTYPE_IN* const in_ptr =
              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

          for (size_t k = 0; k < inner; ++k) {
            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
          }
          out_ptr += inner;
        });
      }
    }
  });

  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
172 changes: 172 additions & 0 deletions
backends/cadence/hifi/third-party/nnlib/xa_nn_concat_32.c
@@ -0,0 +1,172 @@
#include "xa_type_def.h" | ||
#include "xa_nn_common.h" | ||
#include "xa_nnlib_kernels_api.h" | ||
#include "xa_nnlib_common_macros.h" | ||
#include "xa_nnlib_err_chk.h" | ||
#include "xa_nnlib_common.h" | ||
|
||
WORD32 xa_nn_concat_32_32(WORD32 * __restrict__ p_out | ||
,const WORD32 *const p_out_shape | ||
,const WORD32 **pp_inps | ||
,const WORD32 *const *pp_inps_shape | ||
,WORD32 num_out_dims | ||
,WORD32 num_inp | ||
,WORD32 num_inp_dims | ||
,WORD32 axis) | ||
{ | ||
XA_NNLIB_ARG_CHK_PTR(p_out, -1); | ||
XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1); | ||
XA_NNLIB_ARG_CHK_PTR(pp_inps, -1); | ||
XA_NNLIB_ARG_CHK_PTR(pp_inps_shape, -1); | ||
/* Pointer alignment checks */ | ||
XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1); | ||
XA_NNLIB_ARG_CHK_ALIGN(pp_inps, sizeof(WORD32 *), -1); | ||
XA_NNLIB_ARG_CHK_ALIGN(pp_inps_shape, sizeof(WORD32 *), -1); | ||
//Validate Arguments | ||
XA_NNLIB_ARG_CHK_COND((num_out_dims <= 0 || num_out_dims > 6), -1); | ||
XA_NNLIB_ARG_CHK_COND((num_inp <= 0 || num_inp > 10), -1); | ||
XA_NNLIB_ARG_CHK_COND((num_inp_dims != num_out_dims), -1); | ||
XA_NNLIB_ARG_CHK_COND((axis < -num_out_dims || axis >= num_out_dims), -1); | ||
|
||
int i = 0, j = 0; | ||
for(i = 0; i < num_out_dims; i++) | ||
{ | ||
XA_NNLIB_ARG_CHK_COND((p_out_shape[i] <= 0), -1); | ||
} | ||
|
||
if(axis < 0) | ||
axis = num_out_dims + axis; | ||
|
||
WORD32 concat_size = 0; | ||
for (i = 0; i < num_inp; i++) | ||
{ | ||
XA_NNLIB_ARG_CHK_PTR(pp_inps[i], -1); | ||
XA_NNLIB_ARG_CHK_PTR(pp_inps_shape[i], -1); | ||
XA_NNLIB_ARG_CHK_ALIGN(pp_inps_shape[i], sizeof(WORD32), -1); | ||
#pragma loop_count min=1 | ||
for(j = 0; j < num_out_dims; j++) | ||
{ | ||
XA_NNLIB_ARG_CHK_COND((pp_inps_shape[i][j] != p_out_shape[j] && j != axis), -1); | ||
} | ||
|
||
XA_NNLIB_ARG_CHK_COND((pp_inps_shape[i][axis] <= 0), -1); | ||
concat_size += pp_inps_shape[i][axis]; | ||
} | ||
|
||
XA_NNLIB_ARG_CHK_COND((p_out_shape[axis] != concat_size), -1); | ||
|
||
//Calculate outer and inner size for axis | ||
WORD32 outer_size = 1; | ||
#pragma no_simd | ||
for(int i = 0; i < axis; i++) | ||
{ | ||
outer_size *= p_out_shape[i]; | ||
} | ||
|
||
WORD32 base_inner_size = 1; | ||
#pragma no_simd | ||
for(int i = axis + 1; i < num_out_dims; i++) | ||
{ | ||
base_inner_size *= p_out_shape[i]; | ||
} | ||
|
||
WORD32 *ptmp_out = p_out; | ||
for(int i = 0; i < num_inp; i++) | ||
{ | ||
const WORD32 copy_size = pp_inps_shape[i][axis] * base_inner_size; | ||
WORD32 *output_ptr = ptmp_out; | ||
const WORD32* input_ptr = pp_inps[i]; | ||
|
||
    if(((copy_size & 1) == 0) && (((concat_size * base_inner_size) & 1) == 0)
       && (((unsigned)input_ptr & 1) == 0) && (((unsigned)output_ptr & 1) == 0))
    {
      if(copy_size <= 8)
      {
        const ae_f32 *pae_inp = (const ae_f32 *)input_ptr;
        for(int k = 0; k < outer_size; k++)
        {
          ae_f32 *pae_out = (ae_f32 *)output_ptr;
#pragma concurrent
#pragma no_simd
          for(int ic = 0; ic < copy_size; ic++)
          {
            *pae_out++ = *pae_inp++;
          }
          output_ptr += concat_size * base_inner_size;
        }
      }
      else
      {
        for(int k = 0; k < outer_size; k++)
        {
          const ae_int32x2 *pae_inp = (const ae_int32x2 *)input_ptr;
          ae_int32x2 *pae_out = (ae_int32x2 *)output_ptr;
          ae_valign inp_a, out_a;
          inp_a = AE_LA64_PP(pae_inp);
          out_a = AE_ZALIGN64();
          for(int ic = 0; ic < (copy_size >> 1); ic++)
          {
            ae_int32x2 d0;
            AE_LA32X2_IP(d0, inp_a, pae_inp);
            AE_SA32X2_IP(d0, out_a, pae_out);
          }
          AE_SA64POS_FP(out_a, pae_out);
          const ae_f32 *puae_inp = (const ae_f32 *)pae_inp;
          ae_f32 *puae_out = (ae_f32 *)pae_out;
#pragma concurrent
          /* Copy the odd trailing element, if any, relative to the advanced tail pointers */
          for(int ic = 0; ic < (copy_size & 1); ic++)
          {
            puae_out[ic] = puae_inp[ic];
          }
          input_ptr += copy_size;
          output_ptr += concat_size * base_inner_size;
        }
      }
    }
    else
    {
      if(copy_size <= 6)
      {
        for(int k = 0; k < outer_size; k++)
        {
#pragma concurrent
#pragma no_unroll
          for(int ic = 0; ic < copy_size; ic++)
          {
            output_ptr[ic] = *input_ptr++;
          }
          output_ptr += concat_size * base_inner_size;
        }
      }
      else
      {
        for(int k = 0; k < outer_size; k++)
        {
          const ae_int32x2 *pae_inp = (const ae_int32x2 *)input_ptr;
          ae_int32x2 *pae_out = (ae_int32x2 *)output_ptr;
          ae_valign inp_a, out_a;
          inp_a = AE_LA64_PP(pae_inp);
          out_a = AE_ZALIGN64();

#pragma concurrent
          for(int ic = 0; ic < copy_size >> 1; ic++)
          {
            ae_int32x2 d0;
            AE_LA32X2_IP(d0, inp_a, pae_inp);
            AE_SA32X2_IP(d0, out_a, pae_out);
          }
          AE_SA64POS_FP(out_a, pae_out);

          for(int ic = 0; ic < (copy_size & 1); ic++)
          {
            output_ptr[copy_size - 1] = input_ptr[copy_size - 1];
          }
          input_ptr += copy_size;
          output_ptr += concat_size * base_inner_size;
        }
      }
    }
    ptmp_out += copy_size;
  }
  return 0;
}
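For reference, here is a plain scalar sketch of the addressing scheme the kernel above implements; it is not part of the commit, the helper name is made up, and argument validation is omitted. As a worked example, concatenating inputs of shapes {2, 2, 4} and {2, 3, 4} along axis 1 into an output of shape {2, 5, 4} gives outer_size = 2, base_inner_size = 4, concat_size = 5, and per-input copy sizes of 8 and 12 words.

/* Non-vectorized sketch of the copy pattern used by xa_nn_concat_32_32 above;
   illustrative only. */
static void concat_32_reference(WORD32 *p_out, const WORD32 *const p_out_shape,
                                const WORD32 **pp_inps,
                                const WORD32 *const *pp_inps_shape,
                                WORD32 num_dims, WORD32 num_inp, WORD32 axis)
{
  WORD32 outer_size = 1, base_inner_size = 1, concat_size = 0;
  for(int d = 0; d < axis; d++) outer_size *= p_out_shape[d];
  for(int d = axis + 1; d < num_dims; d++) base_inner_size *= p_out_shape[d];
  for(int i = 0; i < num_inp; i++) concat_size += pp_inps_shape[i][axis];

  WORD32 *ptmp_out = p_out;  /* start of the current input's slot along axis */
  for(int i = 0; i < num_inp; i++)
  {
    const WORD32 copy_size = pp_inps_shape[i][axis] * base_inner_size;
    for(int k = 0; k < outer_size; k++)
    {
      /* block k of input i is written at a stride of concat_size * base_inner_size */
      const WORD32 *src = pp_inps[i] + k * copy_size;
      WORD32 *dst = ptmp_out + k * concat_size * base_inner_size;
      for(int c = 0; c < copy_size; c++)
        dst[c] = src[c];
    }
    ptmp_out += copy_size;  /* advance past this input's slice of the axis */
  }
}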