Skip to content

Commit

Permalink
Adding bmm operator kernel optimization (#23)
Browse files Browse the repository at this point in the history
* Adding bmm operator kernel optimization

* Adding bmm operator kernel optimization

* Adding mm operator kernel optimization

---------

Co-authored-by: dijopaul <[email protected]>
  • Loading branch information
Rushi-cad and dijopaul authored Oct 23, 2024
1 parent 813e6c1 commit d653128
Show file tree
Hide file tree
Showing 7 changed files with 555 additions and 2 deletions.
7 changes: 6 additions & 1 deletion backends/cadence/aot/functions_hifi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
- op: bmm.out
kernels:
- arg_meta: null
kernel_name: torch::executor::bmm_out
kernel_name: impl::HiFi::bmm_out

- op: cat.out
kernels:
Expand Down Expand Up @@ -81,6 +81,11 @@
kernels:
- arg_meta: null
kernel_name: impl::HiFi::minimum_out

- op: mm.out
kernels:
- arg_meta: null
kernel_name: impl::HiFi::mm_out

- op: mul.out
kernels:
Expand Down
1 change: 1 addition & 0 deletions backends/cadence/hifi/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_library(
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_mul_f32_broadcast.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_elm_where_f32xf32_f32.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_reduce_32_32.c
${EXECUTORCH_ROOT}/backends/cadence/hifi/third-party/nnlib/xa_nn_transpose_32.c
)

# Let files say "include <executorch/path/to/header.h>".
Expand Down
8 changes: 8 additions & 0 deletions backends/cadence/hifi/kernels/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,14 @@ extern "C" WORD32 xa_nn_elm_where_broadcast_4D_f32xf32_f32(
const unsigned char* __restrict__ p_condition,
const WORD32* const p_condition_shape);

// Generic 32-bit transpose from Cadence NNLib: permutes the input tensor
// p_inp (dimensions p_inp_shape, num_inp_dims axes) according to
// p_permute_vec and writes the result into p_out (dimensions p_out_shape,
// num_out_dims axes). The kernel moves raw 32-bit words, so 32-bit float
// data can be transposed through WORD32* pointer casts.
// Returns an NNLib status code — presumably 0 on success; confirm against
// the NNLib API reference.
extern "C" WORD32 xa_nn_transpose_32_32(WORD32 * __restrict__ p_out,
const WORD32 *const p_out_shape,
const WORD32 * __restrict__ p_inp,
const WORD32 *const p_inp_shape,
const WORD32 * __restrict__ p_permute_vec,
WORD32 num_out_dims,
WORD32 num_inp_dims);

namespace impl {
namespace HiFi {
namespace kernels {
Expand Down
3 changes: 2 additions & 1 deletion backends/cadence/hifi/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,18 @@ endif()
# ATen compliant ops that are needed to run this model.
set(_aten_ops__srcs
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_add.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_bmm.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_clamp.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_div.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_maximum.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mean.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_minimum.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mm.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_mul.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sigmoid.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_sub.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_tanh.cpp"
"${EXECUTORCH_ROOT}/backends/cadence/hifi/operators/op_where.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_cat.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp"
"${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp"
Expand Down
156 changes: 156 additions & 0 deletions backends/cadence/hifi/operators/op_bmm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using Tensor = exec_aten::Tensor;
using exec_aten::ScalarType;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::kTensorDimensionLimit;
using torch::executor::Error;

namespace impl {
namespace HiFi {
namespace native {

Tensor& bmm_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat2,
    Tensor& out) {
  // Batched matrix multiply: out[i] = in[i] @ mat2[i] for every batch i.
  // check_bmm_args enforces 3-D inputs: in is (B, m, n), mat2 is (B, n, p),
  // out is resized to (B, m, p) below.
  ET_KERNEL_CHECK(ctx, check_bmm_args(in, mat2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_bmm_out_target_size(in, mat2, output_sizes, &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "bmm.out";
  constexpr int kNnlibMaxDim = 3; // NNLib transpose supports at most 3 dims.

  // The NNLib fast path only handles float32 tensors of rank <= 3; anything
  // else falls through to the portable reference implementation below.
  bool optimized = true;

  if (out.scalar_type() != ScalarType::Float)
    optimized = false;

  if (in.dim() > kNnlibMaxDim)
    optimized = false;

  if (optimized) {
    const float* in_data = in.const_data_ptr<float>();
    const float* mat2_data = mat2.const_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();

    int64_t batch_size = in.size(0);
    int64_t m = in.size(1);
    int64_t n = in.size(2);
    int64_t p = mat2.size(2);

    WORD32 rows = m;
    WORD32 cols1 = n;
    WORD32 row_stride1 = n;
    WORD32 vec_count = p;
    WORD32 vec_offset = n;
    WORD32 out_offset = 1;
    WORD32 out_stride = p;

    // Zero-filled bias buffer handed to the matmul kernel (bmm has no bias
    // term). Sized generously; presumably only vec_count entries are read —
    // confirm against the NNLib xa_nn_matmul API before shrinking.
    float* tmp = (float*)calloc((batch_size * m * p), sizeof(float));
    // Scratch for the transposed mat2 slice. Each iteration transposes one
    // (n, p) slice, so this buffer must hold n * p elements. The previous
    // size of batch_size * m * p could be SMALLER than n * p (e.g. B = 1,
    // m = 1, large n), which overflowed the heap during the transpose.
    WORD32* p_o = (WORD32*)malloc((n * p) * sizeof(float));

    if (tmp == nullptr || p_o == nullptr) {
      // free(nullptr) is a no-op, so partial-failure cleanup is safe.
      free(tmp);
      free(p_o);
      ET_KERNEL_CHECK(ctx, false, MemoryAllocationFailed, out);
    }

    // Transpose parameters are identical for every batch; set them up once.
    WORD32 p_inp_shape[kNnlibMaxDim] = {(WORD32)n, (WORD32)p, 1};
    WORD32 p_out_shape[kNnlibMaxDim] = {(WORD32)p, (WORD32)n, 1};
    WORD32 p_permute_vec[kNnlibMaxDim] = {1, 0, 2};

    for (int i = 0; i < batch_size; ++i) {
      const FLOAT32* __restrict__ p_mat1 = in_data + i * m * n;
      const FLOAT32* __restrict__ p_vec1 = mat2_data + i * n * p;
      FLOAT32* __restrict__ p_out = out_data + i * m * p;
      const FLOAT32* __restrict__ p_bias = (const FLOAT32* __restrict__)tmp;

      // Transpose mat2[i] from (n, p) to (p, n): the matmul kernel walks
      // each "vector" (output column) contiguously. The 32-bit word
      // transpose is bit-pattern-safe for float32 data.
      xa_nn_transpose_32_32(
          p_o,
          p_out_shape,
          (WORD32*)p_vec1,
          p_inp_shape,
          p_permute_vec,
          kNnlibMaxDim,
          kNnlibMaxDim);

      const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;

      xa_nn_matmul_f32xf32_f32(
          p_out,
          p_mat1,
          p_vec,
          p_bias,
          rows,
          cols1,
          row_stride1,
          vec_count,
          vec_offset,
          out_offset,
          out_stride);
    }

    free(tmp);
    free(p_o);

    return out;
  }

  // Portable fallback: naive per-batch vec_matmul over all real + Half types.
  ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    const CTYPE* in_data = in.const_data_ptr<CTYPE>();
    const CTYPE* mat2_data = mat2.const_data_ptr<CTYPE>();
    CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

    int64_t batch_size = in.size(0);
    int64_t m = in.size(1);
    int64_t n = in.size(2);
    int64_t p = mat2.size(2);

    for (int i = 0; i < batch_size; ++i) {
      const CTYPE* in_data_offset = in_data + i * m * n;
      const CTYPE* mat2_data_offset = mat2_data + i * n * p;
      CTYPE* out_data_offset = out_data + i * m * p;

      torch::executor::vec_matmul<CTYPE>(
          out_data_offset, in_data_offset, mat2_data_offset, m, n, p);
    }
  });

  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
141 changes: 141 additions & 0 deletions backends/cadence/hifi/operators/op_mm.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/matmul_ops_util.h>
#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

using exec_aten::ScalarType;
using exec_aten::Tensor;
using executorch::runtime::KernelRuntimeContext;
using executorch::runtime::kTensorDimensionLimit;
using torch::executor::Error;

namespace impl {
namespace HiFi {
namespace native {

Tensor& mm_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    const Tensor& mat2,
    Tensor& out) {
  // Matrix multiply: out = in @ mat2, with in (m, n), mat2 (n, p) and out
  // resized to (m, p) below.
  ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out);

  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_mm_out_target_size(in, mat2, output_sizes, &output_ndim);
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, mat2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

  ScalarType out_type = out.scalar_type();

  constexpr auto name = "mm.out";

  // The NNLib fast path only handles float32; anything else falls through to
  // the portable reference implementation below.
  bool optimized = true;

  if (out_type != ScalarType::Float)
    optimized = false;

  if (optimized) {
    const float* in_data = in.const_data_ptr<float>();
    const float* mat2_data = mat2.const_data_ptr<float>();
    float* out_data = out.mutable_data_ptr<float>();

    int64_t m = in.size(0);
    int64_t n = in.size(1);

    int64_t p = mat2.size(1);

    WORD32 rows = m;
    WORD32 cols1 = n;
    WORD32 row_stride1 = n;
    WORD32 vec_count = p;
    WORD32 vec_offset = n;
    WORD32 out_offset = 1;
    WORD32 out_stride = p;

    // Scratch for the transposed mat2: n * p 32-bit elements.
    WORD32* p_o = (WORD32*)malloc((n * p) * sizeof(float));
    ET_KERNEL_CHECK(ctx, p_o != nullptr, MemoryAllocationFailed, out);

    WORD32 p_inp_shape[2] = {(WORD32)n, (WORD32)p};
    WORD32 p_out_shape[2] = {(WORD32)p, (WORD32)n};
    WORD32 p_permute_vec[2] = {1, 0};

    WORD32 num_out_dims = 2;
    WORD32 num_inp_dims = 2;

    const FLOAT32* __restrict__ p_mat1 = in_data;
    const FLOAT32* __restrict__ p_vec1 = mat2_data;
    FLOAT32* __restrict__ p_out = out_data;

    WORD32* p_inp = (WORD32*)p_vec1;

    // Transpose mat2 from (n, p) to (p, n): the matmul kernel walks each
    // "vector" (output column) contiguously. The 32-bit word transpose is
    // bit-pattern-safe for float32 data. Return codes were previously
    // captured into unused locals; they are deliberately ignored here,
    // matching bmm_out.
    xa_nn_transpose_32_32(
        p_o,
        p_out_shape,
        p_inp,
        p_inp_shape,
        p_permute_vec,
        num_out_dims,
        num_inp_dims);

    const FLOAT32* __restrict__ p_vec = (const FLOAT32* __restrict__)p_o;

    // NULL bias: presumably NNLib treats this as "no bias" — confirm
    // against the xa_nn_matmul API reference.
    xa_nn_matmul_f32xf32_f32(
        p_out,
        p_mat1,
        p_vec,
        NULL,
        rows,
        cols1,
        row_stride1,
        vec_count,
        vec_offset,
        out_offset,
        out_stride);

    free(p_o);
    return out;
  }

  // Portable fallback: naive vec_matmul over real + Half + BFloat16 types.
  ET_SWITCH_REAL_TYPES_AND2(
      Half, BFloat16, in.scalar_type(), ctx, name, CTYPE, [&]() {
        size_t m = in.size(0);
        size_t n = in.size(1);
        size_t p = mat2.size(1);

        torch::executor::vec_matmul<CTYPE>(
            out.mutable_data_ptr<CTYPE>(),
            in.const_data_ptr<CTYPE>(),
            mat2.const_data_ptr<CTYPE>(),
            m,
            n,
            p);
      });

  return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
Loading

0 comments on commit d653128

Please sign in to comment.