forked from cad-audio/executorch
Adding cat, full, permute_copy and relu ops (#34)
* Adding cat, full, permute_copy
1 parent 07743ab · commit d730ed8
Showing 10 changed files with 1,025 additions and 9 deletions.
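For context: cat concatenates a list of tensors along one dimension, full creates a tensor filled with a scalar, permute_copy returns a copy with its dimensions reordered, and relu zeroes out negative elements. A minimal sketch of the cat semantics this commit implements, using plain C++ vectors instead of ExecuTorch tensors (shapes and values are illustrative only, not part of the commit):

#include <cstdio>
#include <vector>

int main() {
  // Two 2x2 "tensors" stored flat, concatenated along dim 0 -> shape (4, 2).
  std::vector<float> a = {1, 2, 3, 4};
  std::vector<float> b = {5, 6, 7, 8};
  std::vector<float> out(a);
  out.insert(out.end(), b.begin(), b.end());
  for (float v : out)
    std::printf("%g ", v); // prints: 1 2 3 4 5 6 7 8
  return 0;
}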
@@ -0,0 +1,158 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cstring>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::aten::RuntimeContext;
using executorch::runtime::getLeadingDims;
using executorch::runtime::getTrailingDims;
using executorch::runtime::resize_tensor;
using executorch::runtime::tensors_have_same_dim_order;
using torch::executor::check_cat_args;
using torch::executor::Error;
using torch::executor::get_cat_out_target_size;

namespace cadence {
namespace impl {
namespace HiFi {
namespace native {

Tensor& cat_out(
    RuntimeContext& ctx,
    exec_aten::ArrayRef<Tensor> tensors,
    int64_t dim,
    Tensor& out) {
  constexpr auto name = "cat.out";
  constexpr int kNnlibMaxDim = 16;

  // Normalize a negative dim up front so the NNLib fast path and the
  // portable fallback both see the same non-negative axis.
  if (dim < 0) {
    dim += out.dim();
  }

  bool optimized = true;

  // The NNLib kernel handles float only; the rank and input count must also
  // fit in the fixed-size kNnlibMaxDim arrays below.
  if (out.scalar_type() != ScalarType::Float)
    optimized = false;
  if (out.dim() > kNnlibMaxDim ||
      tensors.size() > static_cast<size_t>(kNnlibMaxDim))
    optimized = false;

  if (optimized) {
    WORD32 num_inp = tensors.size();
    WORD32 num_inp_dims = out.dim();
    WORD32 num_out_dims = num_inp_dims;
    WORD32 axis = dim;

    WORD32 inp_shape[kNnlibMaxDim][kNnlibMaxDim];
    WORD32 p_out_shape[kNnlibMaxDim];

    WORD32* ptr_shape[kNnlibMaxDim];
    const WORD32* ptr[kNnlibMaxDim];

    // Collect data pointers and shapes for the non-empty inputs only;
    // empty tensors contribute nothing to the concatenation.
    int k = 0;
    for (int i = 0; i < num_inp; i++) {
      if (tensors[i].numel() == 0)
        continue;
      ptr[k] = (const WORD32*)tensors[i].const_data_ptr<float>();
      for (int j = 0; j < num_inp_dims; j++) {
        inp_shape[k][j] = tensors[i].size(j);
      }
      ptr_shape[k] = inp_shape[k];
      k++;
    }

    num_inp = k;

    for (int i = 0; i < num_out_dims; i++) {
      p_out_shape[i] = out.size(i);
    }

    const WORD32** pp_inps = &ptr[0];

    WORD32* p_out = (WORD32*)out.mutable_data_ptr<float>();

    const WORD32* const* pp_inps_shape = (const WORD32* const*)&ptr_shape[0];

    WORD32 ret_val = xa_nn_concat_32_32(
        p_out,
        p_out_shape,
        pp_inps,
        pp_inps_shape,
        num_out_dims,
        num_inp,
        num_inp_dims,
        axis);

    ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);

    return out;
  }

  // dim was already normalized above; validate the inputs for cat.
  ET_KERNEL_CHECK(ctx, check_cat_args(tensors, dim, out), Internal, out);

  Tensor::SizesType
      expected_out_size[executorch::runtime::kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_cat_out_target_size(tensors, dim, expected_out_size, &expected_out_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  // Special handling when all inputs are 1-D empty tensors, for aten
  // consistency: in that case, just return a 1-D empty tensor without
  // checking dim.
  bool all_1d_empty = true;
  for (size_t i = 0; i < tensors.size(); ++i) {
    if (tensors[i].numel() != 0 || tensors[i].dim() != 1) {
      all_1d_empty = false;
      break;
    }
  }
  if (all_1d_empty) {
    return out;
  }

  // Portable fallback: for each leading-dim slice, copy each input's
  // contiguous chunk along dim into the output, converting element types
  // as needed.
  const size_t outer = getLeadingDims(out, dim);
  const size_t dim_stride = getTrailingDims(out, dim);
  const size_t ninputs = tensors.size();

  const auto out_type = out.scalar_type();
  ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
    CTYPE_OUT* out_ptr = out.mutable_data_ptr<CTYPE_OUT>();
    for (size_t i = 0; i < outer; ++i) {
      for (size_t j = 0; j < ninputs; ++j) {
        const auto in_type = tensors[j].scalar_type();
        ET_SWITCH_REALHB_TYPES(in_type, ctx, name, CTYPE_IN, [&] {
          if (tensors[j].numel() == 0) {
            return;
          }
          size_t inner = tensors[j].size(dim) * dim_stride;
          const CTYPE_IN* const in_ptr =
              tensors[j].const_data_ptr<CTYPE_IN>() + i * inner;

          for (size_t k = 0; k < inner; ++k) {
            out_ptr[k] = static_cast<CTYPE_OUT>(in_ptr[k]);
          }
          out_ptr += inner;
        });
      }
    }
  });

  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
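The scalar fallback above is the standard strided concatenation: once dim is normalized, all leading dimensions collapse into outer and all trailing dimensions into dim_stride, so each input contributes one contiguous chunk of size(dim) * dim_stride elements per outer slice. A self-contained sketch of that same pattern on raw arrays (the function name and shapes below are hypothetical, not part of the commit):

#include <cstddef>
#include <cstdio>
#include <vector>

// Concatenate n inputs of shape (outer, len_j, inner_elems) along the middle
// dim, mirroring the outer/ninputs/inner loop structure in cat_out above.
void concat_middle(
    const std::vector<const float*>& ins,
    const std::vector<size_t>& lens, // each input's size along the cat dim
    size_t outer,
    size_t inner_elems, // product of the trailing dims
    float* out) {
  float* out_ptr = out;
  for (size_t i = 0; i < outer; ++i) {
    for (size_t j = 0; j < ins.size(); ++j) {
      size_t chunk = lens[j] * inner_elems;
      const float* in_ptr = ins[j] + i * chunk;
      for (size_t k = 0; k < chunk; ++k)
        out_ptr[k] = in_ptr[k];
      out_ptr += chunk;
    }
  }
}

int main() {
  // Inputs of shape (2, 1, 2) and (2, 2, 2), cat on dim 1 -> shape (2, 3, 2).
  float a[] = {1, 2, 3, 4};
  float b[] = {5, 6, 7, 8, 9, 10, 11, 12};
  float out[12];
  concat_middle({a, b}, {1, 2}, /*outer=*/2, /*inner_elems=*/2, out);
  for (float v : out)
    std::printf("%g ", v); // prints: 1 2 5 6 7 8 3 4 9 10 11 12
  return 0;
}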