diff --git a/backends/cadence/aot/functions_hifi.yaml b/backends/cadence/aot/functions_hifi.yaml
index 49cf856c49..31d61d5d94 100644
--- a/backends/cadence/aot/functions_hifi.yaml
+++ b/backends/cadence/aot/functions_hifi.yaml
@@ -30,7 +30,7 @@
 - op: bmm.out
   kernels:
     - arg_meta: null
-      kernel_name: torch::executor::bmm_out
+      kernel_name: impl::HiFi::bmm_out
 
 - op: cat.out
   kernels:
@@ -81,6 +81,11 @@
   kernels:
     - arg_meta: null
       kernel_name: impl::HiFi::minimum_out
+
+- op: mm.out
+  kernels:
+    - arg_meta: null
+      kernel_name: impl::HiFi::mm_out
 
 - op: mul.out
   kernels:
diff --git a/backends/cadence/hifi/kernels/CMakeLists.txt b/backends/cadence/hifi/kernels/CMakeLists.txt
index e37b0ffafa..3811380559 100644
--- a/backends/cadence/hifi/kernels/CMakeLists.txt
+++ b/backends/cadence/hifi/kernels/CMakeLists.txt
@@ -18,6 +18,7 @@ add_library(
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
   ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
+  ${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
 )
 
 # Let files say "include <executorch/path/to/header.h>".
diff --git a/backends/cadence/hifi/kernels/kernels.h b/backends/cadence/hifi/kernels/kernels.h
index c01cdd38d5..d728545f25 100644
--- a/backends/cadence/hifi/kernels/kernels.h
+++ b/backends/cadence/hifi/kernels/kernels.h
@@ -136,6 +136,14 @@ extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(
     const unsigned char* __restrict__ p_condition,
     const WORD32* const p_condition_shape);
 
+extern "C" WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out,
+    const WORD32 *const p_out_shape,
+    const WORD32 * __restrict__ p_inp,
+    const WORD32 *const p_inp_shape,
+    const WORD32 * __restrict__ p_permute_vec,
+    WORD32 num_out_dims,
+    WORD32 num_inp_dims);
+
 namespace impl {
 namespace HiFi {
 namespace kernels {
diff --git a/backends/cadence/hifi/operators/CMakeLists.txt b/backends/cadence/hifi/operators/CMakeLists.txt
index 4aea6d30a4..77fcbbbaa5 100644
--- a/backends/cadence/hifi/operators/CMakeLists.txt
+++ b/backends/cadence/hifi/operators/CMakeLists.txt
@@ -21,17 +21,18 @@ endif()
 # ATen compliant ops that are needed to run this model.
 set(_aten_ops__srcs
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp"
+    "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mm.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
     "${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
-    "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
     "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
diff --git a/backends/cadence/hifi/operators/op_bmm.cpp b/backends/cadence/hifi/operators/op_bmm.cpp
new file mode 100644
index 0000000000..d9739e8cf5
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_bmm.cpp
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/kernels/portable/cpu/vec_ops.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+using Tensor = exec_aten::Tensor;
+using exec_aten::ScalarType;
+using executorch::runtime::KernelRuntimeContext;
+using executorch::runtime::kTensorDimensionLimit;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
+namespace native {
+
+Tensor& bmm_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mat2,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, check_bmm_args(in, mat2, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
+  size_t output_ndim = 0;
+  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_bmm_out_target_size(in, mat2, output_sizes, &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  constexpr auto name = "bmm.out";
+  constexpr int kNnlibMaxDim = 3;
+
+  bool optimized = 1;
+
+  if (out.scalar_type() != ScalarType::Float)
+    optimized = 0;
+
+  if (in.dim() > kNnlibMaxDim)
+    optimized = 0;
+
+  if (optimized) {
+    const float* in_data = in.const_data_ptr<float>();
+    const float* mat2_data = mat2.const_data_ptr<float>();
+    float* out_data = out.mutable_data_ptr<float>();
+
+    int64_t batch_size = in.size(0);
+    int64_t m = in.size(1);
+    int64_t n = in.size(2);
+    int64_t p = mat2.size(2);
+
+    WORD32 rows = m;
+    WORD32 cols1 = n;
+    WORD32 row_stride1 = n;
+    WORD32 vec_count = p;
+    WORD32 vec_offset = n;
+    WORD32 out_offset = 1;
+    WORD32 out_stride = p;
+
+    float* tmp = (float*)calloc((batch_size * m * p), sizeof(float));
+    WORD32* p_o = (WORD32*)malloc((batch_size * m * p) * sizeof(float));
+
+    for (int i = 0; i < batch_size; ++i) {
+      const FLOAT32* __restrict__ p_mat1 = in_data + i * m * n;
+      const FLOAT32* __restrict__ p_vec1 = mat2_data + i * n * p;
+      FLOAT32* __restrict__ p_out = out_data + i * m * p;
+      const FLOAT32* __restrict__ p_bias = (const FLOAT32* __restrict__)tmp;
+
+      WORD32* p_inp = (WORD32*)p_vec1;
+
+      WORD32 p_inp_shape[kNnlibMaxDim];
+      p_inp_shape[0] = n;
+      p_inp_shape[1] = p;
+      p_inp_shape[2] = 1;
+
+      WORD32 p_out_shape[kNnlibMaxDim];
+      p_out_shape[0] = p;
+      p_out_shape[1] = n;
+      p_out_shape[2] = 1;
+
+      WORD32 p_permute_vec[kNnlibMaxDim] = {1, 0, 2};
+
+      WORD32 num_out_dims = kNnlibMaxDim;
+      WORD32 num_inp_dims = kNnlibMaxDim;
+
+      xa_nn_transpose_32_32(
+          p_o,
+          p_out_shape,
+          p_inp,
+          p_inp_shape,
+          p_permute_vec,
+          num_out_dims,
+          num_inp_dims);
+
+      const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;
+
+      xa_nn_matmul_f32xf32_f32(
+          p_out,
+          p_mat1,
+          p_vec,
+          p_bias,
+          rows,
+          cols1,
+          row_stride1,
+          vec_count,
+          vec_offset,
+          out_offset,
+          out_stride);
+    }
+
+    free(tmp);
+    free(p_o);
+
+    return out;
+  }
+
+  ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const CTYPE* in_data = in.const_data_ptr<CTYPE>();
+    const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
+    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
+
+    int64_t batch_size = in.size(0);
+    int64_t m = in.size(1);
+    int64_t n = in.size(2);
+    int64_t p = mat2.size(2);
+
+    for (int i = 0; i < batch_size; ++i) {
+      const CTYPE* in_data_offset = in_data + i * m * n;
+      const CTYPE* mat2_data_offset = mat2_data + i * n * p;
+      CTYPE* out_data_offset = out_data + i * m * p;
+
+      torch::executor::vec_matmul<CTYPE>(
+          out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
+    }
+  });
+
+  return out;
+}
+
+} // namespace native
+} // namespace HiFi
+} // namespace impl
diff --git a/backends/cadence/hifi/operators/op_mm.cpp b/backends/cadence/hifi/operators/op_mm.cpp
new file mode 100644
index 0000000000..ebbc7f705d
--- /dev/null
+++ b/backends/cadence/hifi/operators/op_mm.cpp
@@ -0,0 +1,141 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/cadence/hifi/kernels/kernels.h>
+#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
+#include <executorch/kernels/portable/cpu/vec_ops.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using executorch::runtime::KernelRuntimeContext;
+using executorch::runtime::kTensorDimensionLimit;
+using torch::executor::Error;
+
+namespace impl {
+namespace HiFi {
+namespace native {
+
+Tensor& mm_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& in,
+    const Tensor& mat2,
+    Tensor& out) {
+  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);
+
+  size_t output_ndim = 0;
+  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
+  get_mm_out_target_size(in, mat2, output_sizes, &output_ndim);
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
+
+  ScalarType out_type = out.scalar_type();
+
+  constexpr auto name = "mm.out";
+
+  bool optimized = 1;
+
+  if (out_type != ScalarType::Float)
+    optimized = 0;
+
+  if (optimized) {
+    const float* in_data = in.const_data_ptr<float>();
+    const float* mat2_data = mat2.const_data_ptr<float>();
+    float* out_data = out.mutable_data_ptr<float>();
+
+    int64_t m = in.size(0);
+    int64_t n = in.size(1);
+
+    int64_t p = mat2.size(1);
+
+    WORD32 rows = m;
+    WORD32 cols1 = n;
+    WORD32 row_stride1 = n;
+    WORD32 vec_count = p;
+    WORD32 vec_offset = n;
+    WORD32 out_offset = 1;
+    WORD32 out_stride = p;
+
+    WORD32* p_o = (WORD32*)malloc((n * p) * sizeof(float));
+
+    WORD32 p_inp_shape[2];
+    p_inp_shape[0] = n;
+    p_inp_shape[1] = p;
+
+    WORD32 p_out_shape[2];
+    p_out_shape[0] = p;
+    p_out_shape[1] = n;
+
+    WORD32 p_permute_vec[2] = {1, 0};
+
+    WORD32 num_out_dims = 2;
+    WORD32 num_inp_dims = 2;
+
+    const FLOAT32* __restrict__ p_mat1 = in_data;
+    const FLOAT32* __restrict__ p_vec1 = mat2_data;
+    FLOAT32* __restrict__ p_out = out_data;
+
+    WORD32* p_inp = (WORD32*)p_vec1;
+
+    WORD32 t = xa_nn_transpose_32_32(
+        p_o,
+        p_out_shape,
+        p_inp,
+        p_inp_shape,
+        p_permute_vec,
+        num_out_dims,
+        num_inp_dims);
+
+    const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;
+
+    WORD32 val = xa_nn_matmul_f32xf32_f32(
+        p_out,
+        p_mat1,
+        p_vec,
+        NULL,
+        rows,
+        cols1,
+        row_stride1,
+        vec_count,
+        vec_offset,
+        out_offset,
+        out_stride);
+
+    free(p_o);
+    return out;
+  }
+
+  ET_SWITCH_REAL_TYPES_AND2(
+      Half, BFloat16, in.scalar_type(), ctx, name, CTYPE, [&]() {
+        size_t m = in.size(0);
+        size_t n = in.size(1);
+        size_t p = mat2.size(1);
+
+        torch::executor::vec_matmul<CTYPE>(
+            out.mutable_data_ptr<CTYPE>(),
+            in.const_data_ptr<CTYPE>(),
+            mat2.const_data_ptr<CTYPE>(),
+            m,
+            n,
+            p);
+      });
+
+  return out;
+}
+
+} // namespace native
+} // namespace HiFi
+} // namespace impl
diff --git a/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
new file mode 100644
index 0000000000..cbcdec8811
--- /dev/null
+++ b/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
@@ -0,0 +1,241 @@
+#include "xa_nnlib_common.h"
+#include "stdio.h"
+/*
+ * Currently only supports up to 5D input tensors.
+ * 1/2/3/4 D input tensors will be scaled up to 5D.
+ * For example, 2x3 -> 1x1x1x2x3.
+ */
+
+WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out
+                            ,const WORD32 *const p_out_shape
+                            ,const WORD32 * __restrict__ p_inp
+                            ,const WORD32 *const p_inp_shape
+                            ,const WORD32 * __restrict__ p_permute_vec
+                            ,WORD32 num_out_dims
+                            ,WORD32 num_inp_dims)
+{
+  /* NULL pointer checks */
+  XA_NNLIB_ARG_CHK_PTR(p_out, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_inp, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_permute_vec, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_out_shape, -1);
+  XA_NNLIB_ARG_CHK_PTR(p_inp_shape, -1);
+
+  /* Invalid input checks */
+  XA_NNLIB_ARG_CHK_COND(((num_inp_dims <= 0) || (num_inp_dims > 5)), -1);
+  XA_NNLIB_ARG_CHK_COND((num_out_dims != num_inp_dims), -1);
+
+  int itr = 0;
+  for(itr=0; itr < num_inp_dims; itr++)
+  {
+    XA_NNLIB_ARG_CHK_COND((p_inp_shape[itr] <= 0), -1);
+  }
+  for(itr=0; itr < num_out_dims; itr++)
+  {
+    XA_NNLIB_ARG_CHK_COND((p_out_shape[itr] <= 0), -1);
+  }
+
+
+  /* Output shape provided must be correct based on input
+   * shape and permute values */
+  for(itr=0; itr < num_out_dims; itr++)
+  {
+    int output_dim = p_out_shape[itr];
+    int expected_dim = p_inp_shape[p_permute_vec[itr]];
+    XA_NNLIB_ARG_CHK_COND((output_dim != expected_dim), -1);
+  }
+
+  /* Pointer alignment checks */
+  XA_NNLIB_ARG_CHK_ALIGN(p_out, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_inp, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_permute_vec, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_out_shape, sizeof(WORD32), -1);
+  XA_NNLIB_ARG_CHK_ALIGN(p_inp_shape, sizeof(WORD32), -1);
+
+  /* Shift all dim with 1 in the outer part */
+  int eff_output_shape[5];
+  int eff_permute_vec[5];
+
+  for(int i = 0; i < num_out_dims; i++)
+  {
+    eff_output_shape[i] = p_out_shape[i];
+    eff_permute_vec[i] = p_permute_vec[i];
+  }
+
+  int one_i=num_out_dims-1, non_one_i=num_out_dims-1;
+  while(one_i > 0 && non_one_i >=0){
+    while(one_i > 0 && eff_output_shape[one_i]!=1){
+      one_i--;
+    }
+    non_one_i = one_i;
+    while(non_one_i >= 0 && eff_output_shape[non_one_i]==1)
+    {
+      non_one_i--;
+    }
+    if(one_i > 0 && non_one_i >=0){
+      int temp;
+      /*swap output_shape*/
+      {
+        temp = eff_output_shape[one_i];
+        eff_output_shape[one_i] = eff_output_shape[non_one_i];
+        eff_output_shape[non_one_i] = temp;
+      }
+      /*swap permute_vec*/
+      {
+        temp = eff_permute_vec[one_i];
+        eff_permute_vec[one_i] = eff_permute_vec[non_one_i];
+        eff_permute_vec[non_one_i] = temp;
+      }
+
+    }
+  }
+
+  /* Promoting lesser dim tensors to 5D tensors.
+   * Also updating the permute_vec and shapes as needed for optimization */
+  int p_5D_inp_shape[5] = {1, 1, 1, 1, 1};
+  int p_5D_out_shape[5] = {1, 1, 1, 1, 1};
+  int p_5D_permute_vec[5] = {0, 1, 2, 3, 4};
+
+  /* Check if any inner inp dimension is same in the output */
+  int last_dim_same = 1, last_n_same_dim = 0;
+  itr = num_inp_dims - 1;
+  while(itr >= 0)
+  {
+    last_n_same_dim = (last_dim_same && (eff_permute_vec[itr] == itr)) ? (last_n_same_dim + 1) : last_n_same_dim;
+    last_dim_same = (eff_permute_vec[itr] == itr) ? last_dim_same & 1 : last_dim_same & 0;
+    itr--;
+  }
+
+  int dims_added = 5 - num_inp_dims;
+  itr = num_inp_dims - 1;
+  int same_count = last_n_same_dim;
+  int count = 4;
+  while(itr >= 0)
+  {
+    p_5D_inp_shape[count] = (same_count > 0) ? p_5D_inp_shape[count]*p_inp_shape[itr] : p_inp_shape[itr];
+    p_5D_out_shape[count] = (same_count > 0) ? p_5D_out_shape[count]*eff_output_shape[itr] : eff_output_shape[itr];
+    same_count--;
+    itr--;
+    count = (same_count > 0) ? count : count - 1;
+  }
+
+  itr = num_inp_dims - 1;
+  same_count = (last_n_same_dim) ? num_inp_dims - (last_n_same_dim - 1) : 0;
+  count = 4;
+  while(itr >= 0)
+  {
+    p_5D_permute_vec[count] = (same_count > 0) ? eff_permute_vec[itr-(last_n_same_dim - 1)] + dims_added + last_n_same_dim - 1 : eff_permute_vec[itr] + dims_added;
+    same_count--;
+    itr--;
+    count--;
+  }
+
+  int out_dim0, out_dim1, out_dim2, out_dim3, out_dim4;
+  int inp_dim1, inp_dim2, inp_dim3, inp_dim4;
+  int inp_stride[5];
+
+  out_dim0 = p_5D_out_shape[0];
+  out_dim1 = p_5D_out_shape[1];
+  out_dim2 = p_5D_out_shape[2];
+  out_dim3 = p_5D_out_shape[3];
+  out_dim4 = p_5D_out_shape[4];
+
+  inp_dim1 = p_5D_inp_shape[1];
+  inp_dim2 = p_5D_inp_shape[2];
+  inp_dim3 = p_5D_inp_shape[3];
+  inp_dim4 = p_5D_inp_shape[4];
+
+  inp_stride[0] = inp_dim1*inp_dim2*inp_dim3*inp_dim4;
+  inp_stride[1] = inp_dim2*inp_dim3*inp_dim4;
+  inp_stride[2] = inp_dim3*inp_dim4;
+  inp_stride[3] = inp_dim4;
+  inp_stride[4] = 1;
+
+  if(last_n_same_dim)
+  {
+    int itr0, itr1, itr2, itr3, itr4;
+    WORD32 *p_inp0 = (WORD32 *)p_inp;
+    for(itr0 = 0; itr0 < out_dim0; itr0++)
+    {
+      WORD32 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+#pragma loop_count min=1
+      for(itr1 = 0; itr1 < out_dim1; itr1++)
+      {
+        WORD32 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+#pragma loop_count min=1
+        for(itr2 = 0; itr2 < out_dim2; itr2++)
+        {
+          WORD32 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+#pragma loop_count min=1
+          for(itr3 = 0; itr3 < out_dim3; itr3++, p_out+=out_dim4)
+          {
+            WORD32 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+            ae_int32x2 *__restrict__ pae_i = (ae_int32x2 *)(p_inp4);
+            ae_int32x2 *__restrict__ pae_o = (ae_int32x2 *)(p_out);
+            ae_valign a_inp = AE_LA64_PP(pae_i);
+            ae_valign a_out = AE_ZALIGN64();
+            ae_int32x2 d0;
+            for(itr4 = 0; itr4 < (out_dim4 >> 1); itr4++)
+            {
+              AE_LA32X2_IP(d0, a_inp, pae_i);
+              AE_SA32X2_IP(d0, a_out, pae_o);
+            }
+            AE_SA64POS_FP(a_out, pae_o);
+            ae_int32 *__restrict__ puae_i = (ae_int32 *)(pae_i);
+            ae_int32 *__restrict__ puae_o = (ae_int32 *)(pae_o);
+#pragma loop_count max=3
+            for(itr4 = 0; itr4 < (out_dim4 & 1); itr4++)
+            {
+              puae_o[itr4] = puae_i[itr4];
+            }
+          }
+        }
+      }
+    }
+  }
+  else
+  {
+    int itr0, itr1, itr2, itr3, itr4;
+    WORD32 *p_inp0 = (WORD32 *)p_inp;
+    for(itr0 = 0; itr0 < out_dim0; itr0++)
+    {
+      WORD32 *p_inp1 = p_inp0+(itr0*inp_stride[p_5D_permute_vec[0]]);
+      for(itr1 = 0; itr1 < out_dim1; itr1++)
+      {
+        WORD32 *p_inp2 = p_inp1+(itr1*inp_stride[p_5D_permute_vec[1]]);
+        for(itr2 = 0; itr2 < out_dim2; itr2++)
+        {
+          WORD32 *p_inp3 = p_inp2+(itr2*inp_stride[p_5D_permute_vec[2]]);
+          for(itr3 = 0; itr3 < out_dim3; itr3++)
+          {
+            WORD32 *p_inp4 = p_inp3+(itr3*inp_stride[p_5D_permute_vec[3]]);
+
+            ae_valign a_out = AE_ZALIGN64();
+            for(itr4 = 0; itr4 < (out_dim4 >> 1); itr4++)
+            {
+              ae_int32x2 d0, d1;
+              ae_int32x2 tmp0;
+
+              d0 = AE_L32_X((ae_int32 *)p_inp4, 0);
+              p_inp4 += inp_stride[p_5D_permute_vec[4]];
+              d1 = AE_L32_X((ae_int32 *)p_inp4, 0);
+              p_inp4 += inp_stride[p_5D_permute_vec[4]];
+
+              tmp0 = AE_SEL32_HH(d0, d1);
+
+              AE_SA32X2_IP(tmp0, a_out, (ae_int32x2 *)p_out);
+            }
+            AE_SA64POS_FP(a_out, p_out);
+#pragma loop_count max=3
+            for(itr4 = 0; itr4 < (out_dim4 & 1); itr4++)
+            {
+              *p_out++ = *p_inp4;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  return 0;
+}
\ No newline at end of file
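
Note (not part of the patch): both new operators compute A (m x n) * B (n x p) by first transposing B with xa_nn_transpose_32_32 and then handing the p rows of B^T to xa_nn_matmul_f32xf32_f32 as vec_count = p vectors of length cols1 = n (vec_offset = n), scattering each dot product with out_offset = 1 within an output row and out_stride = p between rows; bmm applies the same mapping per batch (advancing the operands by m*n, n*p and m*p) and passes a calloc'ed zero buffer as the bias, while mm passes NULL. The sketch below is a plain C++ reference for that call pattern, handy for sanity-checking the parameter mapping against a naive matmul. The helper names are illustrative, and the nnlib contract is inferred from how the kernels are called here, not from nnlib documentation.

// reference_mm.cpp -- illustrative only; transpose_ref/matmul_ref are
// hypothetical stand-ins for the nnlib calls, under the assumptions above.
#include <cstdio>
#include <vector>

// Mirrors what xa_nn_transpose_32_32 is used for above: B (n x p) -> Bt (p x n).
static void transpose_ref(const float* b, float* bt, int n, int p) {
  for (int r = 0; r < n; ++r)
    for (int c = 0; c < p; ++c)
      bt[c * n + r] = b[r * p + c];
}

// Mirrors the assumed matmul contract as it is called above: vec_count vectors
// of length cols1, read from bt at stride vec_offset, each dotted with every
// row of a; results land at out[row * out_stride + v * out_offset].
static void matmul_ref(
    float* out, const float* a, const float* bt,
    int rows, int cols1, int row_stride1,
    int vec_count, int vec_offset, int out_offset, int out_stride) {
  for (int v = 0; v < vec_count; ++v) {      // one output column per vector
    const float* vec = bt + v * vec_offset;
    for (int r = 0; r < rows; ++r) {         // one row of A per dot product
      float acc = 0.f;
      for (int k = 0; k < cols1; ++k)
        acc += a[r * row_stride1 + k] * vec[k];
      out[r * out_stride + v * out_offset] = acc;
    }
  }
}

int main() {
  const int m = 2, n = 3, p = 2;
  std::vector<float> a = {1, 2, 3, 4, 5, 6};  // 2x3, row-major
  std::vector<float> b = {1, 0, 0, 1, 1, 1};  // 3x2, row-major
  std::vector<float> bt(p * n), out(m * p);
  transpose_ref(b.data(), bt.data(), n, p);
  // Same parameter mapping as op_mm.cpp: rows=m, cols1=n, row_stride1=n,
  // vec_count=p, vec_offset=n, out_offset=1, out_stride=p.
  matmul_ref(out.data(), a.data(), bt.data(), m, n, n, p, n, 1, p);
  std::printf("%g %g\n%g %g\n", out[0], out[1], out[2], out[3]);  // expect 4 5 / 10 11
  return 0;
}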