diff --git a/.gitmodules b/.gitmodules index d1ab8b9aa7..58f2133ed6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -64,6 +64,9 @@ [submodule "third-party/pybind11"] path = third-party/pybind11 url = https://github.com/pybind/pybind11.git +[submodule "backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3"] + path = backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 + url = https://github.com/foss-xtensa/nnlib-FusionG3/ [submodule "third-party/ao"] path = third-party/ao url = https://github.com/pytorch/ao.git diff --git a/backends/cadence/CMakeLists.txt b/backends/cadence/CMakeLists.txt index 3c1aa2945a..3cd880622c 100644 --- a/backends/cadence/CMakeLists.txt +++ b/backends/cadence/CMakeLists.txt @@ -76,7 +76,12 @@ endif() if(EXECUTORCH_NNLIB_OPT) set(TARGET_DIR hifi) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) +endif() + +if(EXECUTORCH_FUSION_G3_OPT) + set(TARGET_DIR fusion_g3) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/third-party/nnlib) endif() add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/operators) -add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/${TARGET_DIR}/kernels) diff --git a/backends/cadence/aot/functions_fusion_g3.yaml b/backends/cadence/aot/functions_fusion_g3.yaml new file mode 100644 index 0000000000..2c162e1444 --- /dev/null +++ b/backends/cadence/aot/functions_fusion_g3.yaml @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This yaml file contains operators that are also defined by the ATen library. +# For lean mode: +# - Codegen'd target `executorch_generated_lib` will be reading all the information +# from this file, including operator schema and kernel metadata. +# - Selective build target `codegen:executorch_defined_ops` now is selecting all the +# operators in this file, by dumping all the op names into `selected_operators.yaml`. +# +# See the README.md file in executorch/kernels/portable for a description of the syntax used +# by this file. 
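+#
+# Each `kernel_name` below names an out-variant kernel. Entries under the
+# `cadence::impl::G3` namespace are implemented later in this diff (for
+# example, `add.out` maps to the function defined in
+# backends/cadence/fusion_g3/operators/op_add.cpp with the signature
+#   Tensor& cadence::impl::G3::native::add_out(
+#       KernelRuntimeContext& ctx,
+#       const Tensor& a, const Tensor& b, const Scalar& alpha, Tensor& out);
+# ), while the `torch::executor` entries fall back to the portable CPU kernels.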
+ + +# aten ops +- op: _to_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::to_copy_out + +- op: _softmax.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::softmax_out + +- op: add.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::add_out + +- op: add.Scalar_out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::add_scalar_out + +- op: bmm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::bmm_out + +- op: cat.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::cat_out + +- op: clone.out + kernels: + - arg_meta: null + kernel_name: torch::executor::clone_out + +- op: div.out + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out + +- op: div.out_mode + kernels: + - arg_meta: null + kernel_name: torch::executor::div_out_mode + +- op: embedding.out + kernels: + - arg_meta: null + kernel_name: torch::executor::embedding_out + +- op: full.out + kernels: + - arg_meta: null + kernel_name: torch::executor::full_out + +- op: mul.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::mul_out + +- op: mul.Scalar_out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::mul_scalar_out + +- op: permute_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::permute_copy_out + +- op: sigmoid.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sigmoid_out + +- op: slice_copy.Tensor_out + kernels: + - arg_meta: null + kernel_name: torch::executor::slice_copy_Tensor_out + +- op: split_with_sizes_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::split_with_sizes_copy_out + +- op: sub.out + kernels: + - arg_meta: null + kernel_name: torch::executor::sub_out + +- op: view_copy.out + kernels: + - arg_meta: null + kernel_name: torch::executor::view_copy_out + +- op: where.self_out + kernels: + - arg_meta: null + kernel_name: torch::executor::where_out + +- op: native_layer_norm.out + kernels: + - arg_meta: null + kernel_name: cadence::impl::G3::native_layer_norm_out \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/CMakeLists.txt b/backends/cadence/fusion_g3/operators/CMakeLists.txt new file mode 100644 index 0000000000..704b4aa741 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/CMakeLists.txt @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.19) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +if(NOT CMAKE_CXX_STANDARD) + set(CMAKE_CXX_STANDARD 17) +endif() + +include(${EXECUTORCH_ROOT}/build/Utils.cmake) +include(${EXECUTORCH_ROOT}/build/Codegen.cmake) + +if(NOT PYTHON_EXECUTABLE) + resolve_python_executable() +endif() + +# ATen compliant ops that are needed to run this model. 
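+# The list below mixes portable reference kernels from kernels/portable/cpu
+# (for ops without Fusion G3 optimized versions) with the G3-optimized sources
+# in this directory (op_add.cpp, op_mul.cpp, op_cat.cpp, op_softmax.cpp,
+# op_native_layer_norm.cpp, op_quantize.cpp, op_dequantize.cpp).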
+set(_aten_ops__srcs + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/activation_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/copy_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/broadcast_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/index_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/kernel_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/matmul_ops_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/reduce_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/repeat_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/slice_util.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_add.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_mul.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_cat.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_softmax.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_native_layer_norm.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_quantize.cpp" + "${CMAKE_CURRENT_SOURCE_DIR}/op_dequantize.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_bmm.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_clone.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_div.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_embedding.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_full.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_permute_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sigmoid.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_slice_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_split_with_sizes_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_sub.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_to_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_view_copy.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/op_where.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/dtype_util.cpp" + "${EXECUTORCH_ROOT}/kernels/portable/cpu/util/normalization_ops_util.cpp" +) +add_library(aten_ops_cadence ${_aten_ops__srcs}) +target_link_libraries(aten_ops_cadence PUBLIC executorch) +target_link_libraries(aten_ops_cadence PRIVATE xa_nnlib) + +# Let files say "include ". +set(_common_include_directories ${EXECUTORCH_ROOT}/..) + +target_include_directories( + aten_ops_cadence PUBLIC ${ROOT_DIR}/.. ${CMAKE_BINARY_DIR} + ${_common_include_directories} + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/algo/common/include/ + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/include/nnlib + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/include + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/algo/kernels/tables/include +) + +# Generate C++ bindings to register kernels into both PyTorch (for AOT) and +# Executorch (for runtime). 
Here select all ops in functions.yaml +gen_selected_ops( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML + "${CMAKE_CURRENT_LIST_DIR}/../../aot/functions_fusion_g3.yaml" "" "" +) +generate_bindings_for_kernels( + LIB_NAME "cadence_ops_lib" OPS_SCHEMA_YAML FUNCTIONS_YAML + ${CMAKE_CURRENT_SOURCE_DIR}/../../aot/functions_fusion_g3.yaml +) +message("Generated files ${gen_command_sources}") + +gen_operators_lib( + LIB_NAME "cadence_ops_lib" KERNEL_LIBS DEPS aten_ops_cadence +) diff --git a/backends/cadence/fusion_g3/operators/op_add.cpp b/backends/cadence/fusion_g3/operators/op_add.cpp new file mode 100644 index 0000000000..6dc710ce6e --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_add.cpp @@ -0,0 +1,257 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::runtime::canCast; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& add_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (canCast(common_type, out.scalar_type()) && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.out"; + + const exec_aten::ArrayRef a_size = a.sizes(); + const exec_aten::ArrayRef b_size = b.sizes(); + const exec_aten::ArrayRef out_size = out.sizes(); + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < a_size.size(); i++) { + inp1_shape[i] = a_size[i]; + } + + for (auto i = 0; i < b_size.size(); i++) { + inp2_shape[i] = b_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool broadcast = (a_is_broadcasted || b_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? 
a.dim() : b.dim(); + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + if (broadcast) { + xa_nn_elm_add_broadcast_5D_32x32_32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + xa_nn_elm_add_32x32_32( + out_data, inp1_data, inp2_data, alpha_val, out.numel()); + } + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + if (broadcast) { + xa_nn_elm_add_broadcast_5D_f32xf32_f32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim, + alpha_val); + } else { + xa_nn_elm_add_f32xf32_f32( + out_data, inp1_data, inp2_data, alpha_val, out.numel()); + } + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to(alpha); + torch::executor::native::utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name>( + [val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a + val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16); + }); + } + + return out; +} + +Tensor& add_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + const Scalar& alpha, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, + (common_type == out.scalar_type() && + torch::executor::check_alpha_type( + torch::executor::native::utils::get_scalar_dtype(alpha), + common_type)), + InvalidArgument, + out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor(out, a.sizes()) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "add.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + int alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + int* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_add_scalar_32x32_32( + out_data, inp1_data, inp2_val, alpha_val, out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + + float alpha_val; + torch::executor::native::utils::extract_scalar(alpha, &alpha_val); + + float* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_add_scalar_f32xf32_f32( + out_data, inp1_data, 
inp2_val, alpha_val, out.numel()); + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [b, alpha](const CTYPE_COMPUTE val_a) { + CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + CTYPE_COMPUTE val_alpha = + torch::executor::native::utils::scalar_to( + alpha); + return val_a + val_alpha * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_cat.cpp b/backends/cadence/fusion_g3/operators/op_cat.cpp new file mode 100644 index 0000000000..62bbb0c9d4 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_cat.cpp @@ -0,0 +1,202 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ +enum datatype { + Ushort = 20, + Uint = 23, +}; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& cat_out( + KernelRuntimeContext& ctx, + exec_aten::ArrayRef tensors, + int64_t dim, + Tensor& out) { + if (dim < 0) { + dim += out.dim(); + } + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_cat_args(tensors, dim, out), + InvalidArgument, + out); + + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; + size_t expected_out_dim = 0; + torch::executor::get_cat_out_target_size( + tensors, dim, expected_out_size, &expected_out_dim); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + out, {expected_out_size, expected_out_dim}) == Error::Ok, + InvalidArgument, + out); + + const signed char* inp_tensors[tensors.size()]; + const int* inp_tensors_shapes[tensors.size()]; + + int inp_shapes_size[tensors.size()]; + + int temp_sizes[tensors.size()][kTensorDimensionLimit]; + exec_aten::ArrayRef temp_size; + + for (int i = 0; i < tensors.size(); i++) { + inp_tensors[i] = tensors[i].const_data_ptr(); + temp_size = tensors[i].sizes(); + + for (int j = 0; j < temp_size.size(); j++) { + temp_sizes[i][j] = temp_size[j]; + } + inp_tensors_shapes[i] = temp_sizes[i]; // input shapes + inp_shapes_size[i] = temp_size.size(); // number of input dimensions + } + + signed char* out_data = out.mutable_data_ptr(); + + const exec_aten::ArrayRef out_size = out.sizes(); + int out_shapes[kTensorDimensionLimit]; + for (int i = 0; i < out_size.size(); i++) // output shapes + { + out_shapes[i] = out_size[i]; + } + + if (out.scalar_type() == ScalarType::Int) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(int)); + } else if (out.scalar_type() == 
ScalarType::Short) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(short)); + } else if (out.scalar_type() == ScalarType::Char) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(char)); + } + if (out.scalar_type() == (ScalarType)Uint) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(int)); + } else if (out.scalar_type() == (ScalarType)Ushort) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(short)); + } else if (out.scalar_type() == ScalarType::Byte) { + xa_nn_cat( + out_data, + out_shapes, + inp_tensors, + inp_tensors_shapes, + inp_shapes_size[0], + tensors.size(), + (int)dim, + sizeof(char)); + + } else { + // Special handling when all inputs are 1D-empty tensors for aten + // consistency In that case, just return an 1D-empty tensor without checking + // dim + bool all_1d_empty = true; + for (size_t i = 0; i < tensors.size(); ++i) { + if (tensors[i].numel() != 0 || tensors[i].dim() != 1) { + all_1d_empty = false; + break; + } + } + if (all_1d_empty) { + return out; + } + + const size_t outer = executorch::runtime::getLeadingDims(out, dim); + const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim); + const size_t ninputs = tensors.size(); + + const auto out_type = out.scalar_type(); + ET_SWITCH_REALHB_TYPES(out_type, ctx, "cat.out", CTYPE_OUT, [&] { + CTYPE_OUT* out_ptr = out.mutable_data_ptr(); + for (size_t i = 0; i < outer; ++i) { + for (size_t j = 0; j < ninputs; ++j) { + const auto in_type = tensors[j].scalar_type(); + ET_SWITCH_REALHB_TYPES(in_type, ctx, "cat.out", CTYPE_IN, [&] { + if (tensors[j].numel() == 0) { + return; + } + size_t inner = tensors[j].size(dim) * dim_stride; + const CTYPE_IN* const in_ptr = + tensors[j].const_data_ptr() + i * inner; + + for (size_t k = 0; k < inner; ++k) { + out_ptr[k] = static_cast(in_ptr[k]); + } + out_ptr += inner; + }); + } + } + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp new file mode 100644 index 0000000000..784011332f --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -0,0 +1,810 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +template +using optional = exec_aten::optional; +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + + enum datatype { + Ushort = 20, + Bits4u = 21, + Bits4 = 22 + }; + +/** + * For an input tensor, use the scale and zero_point arguments to quantize it. 
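+ *
+ * (This file provides the corresponding dequantize kernels: the same scale and
+ * zero_point arguments map quantized values back to floating point via
+ * (input - zero_point) * scale.)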
+ */ +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +/** + * Asserts that the parameters are valid. + */ +void check_dequantize_per_tensor_args(const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional& out_dtype, + Tensor& out) +{ + ET_CHECK_MSG( + input.scalar_type() == ScalarType::Byte || + input.scalar_type() == ScalarType::Char || + input.scalar_type() == ScalarType::Bits16 || + input.scalar_type() == ScalarType::Short || + input.scalar_type() == (ScalarType) Ushort || + input.scalar_type() == (ScalarType) Bits4 || + input.scalar_type() == (ScalarType) Bits4u || + input.scalar_type() == ScalarType::Int, + + "input.scalar_type() %" PRId8 " is not supported:", + static_cast(input.scalar_type())); + + ET_CHECK_MSG( + input.scalar_type() == dtype, + "input.scalar_type() %" PRId8 " is not matching dtype argumenta:", + static_cast(input.scalar_type())); + + if (out_dtype.has_value()) { + ET_CHECK_MSG( + out.scalar_type() == out_dtype.value(), + "output_dtype must match the dtype of the out tensor"); + } + + ET_CHECK_MSG( + quant_min <= quant_max, + "quant min: %" PRId64 " is greater than quant max: %" PRId64, + quant_min, + quant_max); +} + +} // namespace + + +/* Local function which calls the kernels based on the input datatype */ +void Dequantize_impl(Tensor& out, + const Tensor& input, + float *scale_data, + int *zero_point_data, + int *axis, + exec_aten::optional out_dtype) +{ + const exec_aten::ArrayRef input_size = input.sizes(); + + int kTensorDimensionLimit = 5; + + int inp_shape[kTensorDimensionLimit]; + + for(auto i = 0; i < input_size.size(); i++) + { + inp_shape[i] = input_size[i]; + } + + bool is_asym_dequant = 0; + + if(zero_point_data != NULL) //asymmetric dequant + { + if(axis != NULL) //channel + { + for(int i = 0; i < input.size(*axis) ; i++) + { + if(zero_point_data[i] != 0) + { + is_asym_dequant |= 1; + } + } + } + else + { + if(*zero_point_data != 0) //tesor + { + is_asym_dequant |= 1; + } + } + } + float* out_data = out.mutable_data_ptr(); + + if(is_asym_dequant) + { + if (input.scalar_type() == ScalarType::Byte) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == ScalarType::Char) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Ushort) + { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == ScalarType::Short) + { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4u) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, + zero_point_data, scale_data); + } + else + { + if(axis == 
NULL) + { + // calculate the dequantized output, cast scale to float to match fbgemm + // behavior + #define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; + #define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); + ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef ASYM_CALCULATE_INT_TYPE_TENSOR + #undef ASYM_DEQUANTIZE_IMPL_TESNOR + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual dequantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are dequantizing. 
+ // in other words you are dequantizing in_data[in_ix] + #define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); + ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef ASYM_CALCULATE_INT_TYPE_CHANNEL + #undef ASYM_DEQUANTIZE_IMPL_CHANNEL + } + } + } + else + { + if (input.scalar_type() == ScalarType::Byte) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == ScalarType::Char) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Ushort) + { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == ScalarType::Short) + { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4u) + { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } + else if (input.scalar_type() == (ScalarType) Bits4) + { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, 
scale_data); + } + else + { + if(axis == NULL) + { + // calculate the dequantized output, cast scale to float to match fbgemm + // behavior + #define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; + #define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); + SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef SYM_DEQUANTIZE_IMPL_TESNOR + #undef SYM_CALCULATE_INT_TYPE_TENSOR + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual dequantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are dequantizing. 
+ // in other words you are dequantizing in_data[in_ix] + #define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); + SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + #undef SYM_DEQUANTIZE_IMPL_CHANNEL + #undef SYM_CALCULATE_INT_TYPE_CHANNEL + } + } + } +} + +/** + * Dequantizes the input tensor according to the formula (input - zero_point) * + * scale + * + * NOTE: quant_min and quant_max are not used in computation, but rather + * metadata that is passed around which can be useful for pattern matching. See + * https://github.com/pytorch/pytorch/pull/87093#discussion_r1000841181 for more + * info. 
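+ *
+ * Worked example with illustrative values (not taken from this diff): for an
+ * asymmetric uint8 input with scale = 0.5 and zero_point = 128, an input
+ * value of 130 dequantizes to (130 - 128) * 0.5 = 1.0f. When every zero_point
+ * is 0, Dequantize_impl above takes the symmetric path and calls the
+ * xa_nn_elm_dequantize_sym* kernels instead.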
+ */ +Tensor& dequantize_per_tensor_out( + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_tensor_out"); + + check_dequantize_per_tensor_args( + input, quant_min, quant_max, dtype, out_dtype, out); + + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + + Dequantize_impl(out, + input, + &scale_data, + &zero_point_data, + NULL, + out_dtype); + + return out; +} + +Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "Expected scale to be Double tensor received: %" PRId8, + static_cast(scale.scalar_type())); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "Expected scale to be Long tensor received: %" PRId8, + static_cast(zero_point.scalar_type())); + ET_CHECK_MSG( + scale.numel() == 1, + "Exepcted scale to only have one element received: %zd", + ssize_t(scale.numel())); + ET_CHECK_MSG( + zero_point.numel() == 1, + "Exepcted zero_point to only have one element received: %zd", + ssize_t(zero_point.numel())); + + dequantize_per_tensor_out( + input, + scale.const_data_ptr()[0], + zero_point.const_data_ptr()[0], + quant_min, + quant_max, + dtype, + out_dtype, + out); + + return out; +} + +Tensor& dequantize_per_channel_out(const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, input.sizes()); + + // normalize axis + ET_CHECK_MSG( + executorch::runtime::tensor_has_dim(input, axis), + "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", + ssize_t(axis), + ssize_t(input.dim())); + + if (axis < 0) + { + axis += executorch::runtime::nonzero_dim(input); + } + + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); + + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "scale.scalar_type() %" PRId8 " is not double type", + static_cast(scale.scalar_type())); + + ET_CHECK_MSG( + scale.numel() == input.size(axis), + "scale.numel() %zd != input.size(axis) %zd", + ssize_t(scale.numel()), + ssize_t(input.size(axis))); + + if (opt_zero_points.has_value()) { + auto zero_point = opt_zero_points.value(); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "zero_point.scalar_type() %" PRId8 " is not integer type", + static_cast(zero_point.scalar_type())); + + ET_CHECK_MSG( + zero_point.numel() == input.size(axis), + "zero_point.numel() %zd != input.size(axis) %zd", + ssize_t(zero_point.numel()), + ssize_t(input.size(axis))); + } + + check_dequantize_per_tensor_args( + input, quant_min, quant_max, dtype, out_dtype, out); + + int *axis_ptr = (int *)&axis; + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt; + int zero_point_data[input.size(axis)]; + int *zero_point_ptr; + if (opt_zero_points.has_value()) + { + zero_point_dt = opt_zero_points.value().const_data_ptr(); + zero_point_ptr = 
&zero_point_data[0]; + for(int i = 0; i < scale.numel(); i++) + { + zero_point_ptr[i] = (int)zero_point_dt[i]; + } + } + else + { + zero_point_ptr = nullptr; + } + float scale_data[input.size(axis)]; + for(int i = 0; i < scale.numel(); i++) + { + scale_data[i] = (float)scale_dt[i]; + } + Dequantize_impl(out, + input, + scale_data, + zero_point_ptr, + axis_ptr, + out_dtype); + + return out; +} + +Tensor& dequantize_per_channel_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); + return dequantize_per_channel_out( + input, + scale, + opt_zero_points, + axis, + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + +Tensor& dequantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +Tensor& dequantize_per_tensor_tensor_args_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) +{ + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_tensor_args_out( + input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); +} + +Tensor& dequantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) +{ + // Refactor this into a util + size_t num_channels = 1; + for (size_t i = 0; i < input.dim() - 1; i++) + { + num_channels *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well + std::array input_sizes; + input_sizes[0] = static_cast(num_channels); + input_sizes[1] = + static_cast(input.size(input.dim() - 1)); +#ifdef USE_ATEN_LIB + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), + input_sizes, + at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in dequantize_per_channel_out"); +#endif + + return dequantize_per_channel_out( + reshaped_input, + scale, + 
zero_points, + 0, /* axis */ + quant_min, + quant_max, + dtype, + out_dtype, + out); +} + +Tensor& dequantize_per_token_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_points, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + ScalarType out_dtype, + Tensor& out) +{ + (void)context; + return dequantize_per_token_out( + input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_mul.cpp b/backends/cadence/fusion_g3/operators/op_mul.cpp new file mode 100644 index 0000000000..366982ae3f --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_mul.cpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using executorch::runtime::canCast; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& mul_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Tensor& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type()); + + // Check Common Dtype + ET_KERNEL_CHECK( + ctx, canCast(common_type, out.scalar_type()), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, b, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, + torch::executor::resize_to_broadcast_target_size(a, b, out) == Error::Ok, + InvalidArgument, + out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.out"; + + const exec_aten::ArrayRef a_size = a.sizes(); + const exec_aten::ArrayRef b_size = b.sizes(); + const exec_aten::ArrayRef out_size = out.sizes(); + + int kTensorDimensionLimit = 5; + + int inp1_shape[kTensorDimensionLimit]; + int inp2_shape[kTensorDimensionLimit]; + int out_shape[kTensorDimensionLimit]; + + /* input shapes and output shapes */ + for (auto i = 0; i < a_size.size(); i++) { + inp1_shape[i] = a_size[i]; + } + + for (auto i = 0; i < b_size.size(); i++) { + inp2_shape[i] = b_size[i]; + } + + for (auto i = 0; i < out_size.size(); i++) { + out_shape[i] = out_size[i]; + } + + /*find broadcast*/ + const bool a_is_broadcasted = !out.sizes().equals(a.sizes()); + const bool b_is_broadcasted = !out.sizes().equals(b.sizes()); + const bool broadcast = (a_is_broadcasted || b_is_broadcasted); + + int max_dim = a.dim() > b.dim() ? 
a.dim() : b.dim(); + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + const int* const inp2_data = b.const_data_ptr(); + int* const out_data = out.mutable_data_ptr(); + + if (broadcast) { + xa_nn_elm_mul_broadcast_5D_32x32_32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim); + } else { + xa_nn_elm_mul_32x32_32(out_data, inp1_data, inp2_data, out.numel()); + } + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + const float* const inp2_data = b.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + + if (broadcast) { + xa_nn_elm_mul_broadcast_5D_f32xf32_f32( + out_data, + out_shape, + inp1_data, + inp1_shape, + inp2_data, + inp2_shape, + max_dim); + } else { + xa_nn_elm_mul_f32xf32_f32(out_data, inp1_data, inp2_data, out.numel()); + } + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + torch::executor::native::utils::apply_bitensor_elementwise_fn< + CTYPE_COMPUTE, + op_name>( + [](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) { + return val_a * val_b; + }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + b, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16); + }); + } + + return out; +} + +Tensor& mul_scalar_out( + KernelRuntimeContext& ctx, + const Tensor& a, + const Scalar& b, + Tensor& out) { + // Common Dtype + ScalarType common_type = + torch::executor::native::utils::promote_type_with_scalar( + a.scalar_type(), b); + + // Check Common Dtype + ET_KERNEL_CHECK(ctx, common_type == out.scalar_type(), InvalidArgument, out); + + // Check Dim Order + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(a, out), + InvalidArgument, + out); + + // Resize + ET_KERNEL_CHECK( + ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out); + + // Compute Dtype + ScalarType compute_type = + torch::executor::native::utils::get_compute_type(common_type); + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "mul.Scalar_out"; + + if (compute_type == ScalarType::Int) { + const int* const inp1_data = a.const_data_ptr(); + int inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + int* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_mul_scalar_32x32_32(out_data, inp1_data, inp2_val, out.numel()); + } else if (compute_type == ScalarType::Float) { + const float* const inp1_data = a.const_data_ptr(); + float inp2_val; + torch::executor::native::utils::extract_scalar(b, &inp2_val); + float* const out_data = out.mutable_data_ptr(); + + xa_nn_elm_mul_scalar_f32xf32_f32( + out_data, inp1_data, inp2_val, out.numel()); + } else { + ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() { + const CTYPE_COMPUTE val_b = + torch::executor::native::utils::scalar_to(b); + torch::executor::native::utils:: + apply_unitensor_elementwise_fn( + [val_b](const CTYPE_COMPUTE val_a) { return val_a * val_b; }, + ctx, + a, + torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16, + out, + torch::executor::native::utils::SupportedTensorDtypes:: + SAME_AS_COMMON); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp 
b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp new file mode 100644 index 0000000000..68d111795c --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_native_layer_norm.cpp @@ -0,0 +1,258 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using Tensor = exec_aten::Tensor; +using ScalarType = exec_aten::ScalarType; +using IntArrayRef = exec_aten::ArrayRef; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +namespace { + +template +void layer_norm( + const Tensor& input, + IntArrayRef normalized_shape, + const exec_aten::optional& weight, + const exec_aten::optional& bias, + CTYPE eps, + Tensor& out, + Tensor& mean, + Tensor& rstd) { + size_t dim = input.dim() - normalized_shape.size(); + size_t dim_size = input.size(dim); + + size_t leading = executorch::runtime::getLeadingDims(input, dim); + size_t normalized = + executorch::runtime::getTrailingDims(input, dim) * dim_size; + + if (leading == 0) { + return; + } + + CTYPE* out_data = out.mutable_data_ptr(); + CTYPE* mean_data = mean.mutable_data_ptr(); + CTYPE* rstd_data = rstd.mutable_data_ptr(); + + if (normalized == 0) { + for (int i = 0; i < leading; ++i) { + mean_data[i] = static_cast(0); + rstd_data[i] = static_cast(NAN); + } + return; + } + + const CTYPE* input_data = input.const_data_ptr(); + const CTYPE* weight_data; + if (weight.has_value()) { + weight_data = weight.value().const_data_ptr(); + } else { + weight_data = nullptr; + } + const CTYPE* bias_data; + if (bias.has_value()) { + bias_data = bias.value().const_data_ptr(); + } else { + bias_data = nullptr; + } + + for (int i = 0; i < leading; ++i) { + const CTYPE* x = input_data + i * normalized; + CTYPE* y = out_data + i * normalized; + + // compute E[X] and Var[x] = E[x^2] - E[x]^2 + CTYPE sum = torch::executor::reduce_add(x, normalized); + CTYPE sq_sum = torch::executor::vec_powerf(x, normalized); + CTYPE mean_value = sum / normalized; + CTYPE variance = sq_sum / normalized - mean_value * mean_value; + CTYPE std = std::sqrt(variance + eps); + + // Calculate the elements of output + for (int j = 0; j < normalized; ++j) { + CTYPE w = weight_data ? weight_data[j] : static_cast(1); + CTYPE b = bias_data ? bias_data[j] : static_cast(0); + y[j] = (x[j] - mean_value) / std * w + b; + } + + mean_data[i] = mean_value; + rstd_data[i] = 1.0 / std; + } +} + +} // namespace + +// native_layer_norm.out(Tensor input, int[] normalized_shape, Tensor? weight, +// Tensor? bias, float eps, *, Tensor(a!) out, Tensor(b!) mean_out, Tensor(c!) 
+// rstd_out) -> (Tensor(a!), Tensor(b!), Tensor(c!)) +// As a reference, there's math_native_layer_norm in ATen: +// https://www.internalfb.com/code/fbsource/[2da5b17b086554c6cd0c3ab08a35aeec2a8bad8c]/xplat/caffe2/aten/src/ATen/native/layer_norm.cpp?lines=188 +std::tuple native_layer_norm_out( + KernelRuntimeContext& ctx, + const Tensor& input, + IntArrayRef normalized_shape, + const exec_aten::optional& weight, + const exec_aten::optional& bias, + double eps, + Tensor& out, + Tensor& mean_out, + Tensor& rstd_out) { + (void)ctx; + + std::tuple ret_val(out, mean_out, rstd_out); + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_layer_norm_args( + input, normalized_shape, weight, bias, out, mean_out, rstd_out), + InvalidArgument, + ret_val); + + // Only support default dim order for now. + // TODO: Support other dim orders. + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_default_dim_order(input), + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order( + input, out, mean_out, rstd_out), + InvalidArgument, + ret_val); + + if (weight.has_value()) { + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(input, weight.value()), + InvalidArgument, + ret_val); + } + + if (bias.has_value()) { + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(input, bias.value()), + InvalidArgument, + ret_val); + } + int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit; + Tensor::SizesType mean_rstd_sizes[kTensorDimensionLimit]; + size_t mean_rstd_ndim = 0; + torch::executor::get_layer_norm_out_target_size( + input, normalized_shape, mean_rstd_sizes, &mean_rstd_ndim); + + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, input.sizes()) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + mean_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok, + InvalidArgument, + ret_val); + + ET_KERNEL_CHECK( + ctx, + executorch::runtime::resize_tensor( + rstd_out, {mean_rstd_sizes, mean_rstd_ndim}) == Error::Ok, + InvalidArgument, + ret_val); + + int input_shape[kTensorDimensionLimit]; + for (int i = 0; i < input.dim(); i++) { + input_shape[i] = input.size(i); + } + + if (out.scalar_type() == ScalarType::Float) { + float* const out_data = out.mutable_data_ptr(); + float* const mean_data = mean_out.mutable_data_ptr(); + float* const rstd_data = rstd_out.mutable_data_ptr(); + const float* const inp_data = input.const_data_ptr(); + int dim = input.dim() - normalized_shape.size(); + + int num_elm = 1; + for (int i = 0; i < normalized_shape.size(); i++) { + num_elm *= normalized_shape[i]; + } + + float* weight_data; + if (weight.has_value()) { + weight_data = weight.value().mutable_data_ptr(); + } else { + weight_data = (float*)malloc(num_elm * sizeof(float)); + for (int i = 0; i < num_elm; i++) { + weight_data[i] = 1; + } + } + float* bias_data; + if (bias.has_value()) { + bias_data = bias.value().mutable_data_ptr(); + } else { + bias_data = (float*)malloc(num_elm * sizeof(float)); + for (int i = 0; i < num_elm; i++) { + bias_data[i] = 0; + } + } + + xa_nn_native_layer_norm_f32_f32( + out_data, + mean_data, + rstd_data, + inp_data, + input_shape, + input.dim(), + dim, + weight_data, + bias_data, + (float)eps); + + if (!bias.has_value()) { + free(bias_data); + } + if (!weight.has_value()) { + free(weight_data); + } + } else { + ET_SWITCH_FLOAT_TYPES( + input.scalar_type(), ctx, "native_layer_norm.out", CTYPE, [&]() { + layer_norm( + input, + 
normalized_shape, + weight, + bias, + eps, + out, + mean_out, + rstd_out); + }); + } + + return ret_val; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp new file mode 100644 index 0000000000..bc84829edb --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -0,0 +1,797 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +/* ScalarType in Executorch do not have support for below data types. + * So, creating a placeholder for these data types. Once, ScalarTypes is + * updated to have support for below data types, these can be removed and + * operator need to be updated accordingly + */ + enum datatype { + Ushort = 20, + Bits4u = 21, + Bits4 = 22 + }; + +/** + * For an input tensor, use the scale and zero_point arguments to quantize it. + */ +namespace cadence { +namespace impl { +namespace FusionG3 { +namespace native { + + +namespace { + +/** + * Asserts that the parameters are valid. + */ +void check_quantize_per_tensor_args(const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + // Ensure self and out has the same shape + ET_CHECK_MSG( + torch::executor::isFloatingType(input.scalar_type()), + "input.scalar_type() %" PRId8 " is not floating type", + static_cast(input.scalar_type())); + + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + ScalarType out_dtype = out.scalar_type(); + ET_CHECK_MSG( + out_dtype == dtype, + "out.scalar_type() %" PRId8 " is not matching dtype argument %" PRId8, + static_cast(out_dtype), + static_cast(dtype)); + + if (out_dtype == ScalarType::Byte) + { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } + else if (dtype == ScalarType::Char) + { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } + else if (dtype == ScalarType::Bits16) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == ScalarType::Short) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == (ScalarType)Ushort) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else if (dtype == (ScalarType)Bits4u) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo unsigned 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } + else if (dtype == (ScalarType)Bits4) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo signed 4-bit data type */ + quant_min_lower_bound = 
quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } + else if (dtype == ScalarType::Int) + { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } + else + { + ET_CHECK_MSG( + false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); + } + + ET_CHECK_MSG( + quant_min >= quant_min_lower_bound, + "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32 + " actual quant_min: %" PRId64, + quant_min_lower_bound, + quant_min); + + ET_CHECK_MSG( + quant_max <= quant_max_upper_bound, + "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32 + " actual quant_max: %" PRId64, + quant_max_upper_bound, + quant_max); +}/* check_quantize_per_tensor_args */ + +} // namespace + +template +T quantize_val( + double scale, + int64_t zero_point, + K value, + int64_t quant_min, + int64_t quant_max) +{ + int64_t qvalue; + float inv_scale = 1.0f / static_cast(scale); + qvalue = static_cast( + static_cast(zero_point) + + std::nearbyint(static_cast(inv_scale * value))); + + qvalue = std::max(qvalue, quant_min); + qvalue = std::min(qvalue, quant_max); + return static_cast(qvalue); +} + + +/* Local function which calls the kernels based on the output datatype */ +void quantize_impl(Tensor& out, + const Tensor& input, + float *scale_data, + int *zero_point_data, + int *axis, + int quant_min, + int quant_max) +{ + const exec_aten::ArrayRef input_size = input.sizes(); + + int kTensorDimensionLimit = 5; + + int inp_shape[kTensorDimensionLimit]; + + for(auto i = 0; i < input_size.size(); i++) + { + inp_shape[i] = input_size[i]; + } + + const float* input_data = input.const_data_ptr(); + + bool is_asym_quant = 0; + + if(zero_point_data != NULL) //asymmetric quant + { + if(axis != NULL) //channel + { + for(int i = 0; i < input.size(*axis) ; i++) + { + if(zero_point_data[i] != 0) + { + is_asym_quant |= 1; + } + } + } + else + { + if(*zero_point_data != 0) //tensor + { + is_asym_quant |= 1; + } + } + } + + if(is_asym_quant) + { + if (out.scalar_type() == ScalarType::Byte) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Char) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Ushort) + { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Short) + { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Bits4u) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType)Bits4) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, zero_point_data, quant_min, quant_max); + } + else + { + if(axis == 
NULL) + { + // Vector quantization + // calculate the quantized input + #define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, (int64_t)*zero_point_data, value, \ + (int64_t)quant_min, (int64_t)quant_max); \ + } \ + } break; + #define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ + ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + + } + else + { + // Channel based quantization + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual quantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are quantizing. 
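+      // (e.g. for the (N, C, H, W), axis == 1 case above, each channel_ix selects one of the
+      // C channels and in_ix ranges over the N*H*W elements belonging to that channel);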
+ // in other words you are quantizing in_data[in_ix] + #define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ + ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + } + + #undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR + #undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL + #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR + #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + } + } + else + { + if (out.scalar_type() == ScalarType::Byte) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Char) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Ushort) + { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == ScalarType::Short) + { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Bits4u) + { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4u( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else if (out.scalar_type() == (ScalarType) Bits4) + { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4( + out_data, input_data, inp_shape, input.dim(), axis, + scale_data, quant_min, quant_max); + } + else + { + if(axis == NULL) + { + // calculate the quantized input + #define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. 
*/ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, (int64_t)*zero_point_data, value, \ + (int64_t)quant_min, (int64_t)quant_max); \ + } \ + } break; + #define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ + SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + + } + else + { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) + { + if (i < *axis) + { + dims[i] = i; + } + else + { + dims[i] = i + 1; + } + } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + + // Actual quantization logic + // input, out are the input and output tensors + // channel_ix is the index along the axis dimension. 0 <= channel_ix < + // input.size(axis). + // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix + // will be 0, 1, 2, ... C-1 + // in_ix is the flat index of the element you are quantizing. + // in other words you are quantizing in_data[in_ix] + #define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; + #define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ + SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); + } + } + #undef SYM_CALCULATE_FLOAT_TYPE_TENSOR + #undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL + #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR + #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + } + } +} + +// Quantize the input tensor +Tensor& quantize_per_tensor_out(KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + torch::executor::Error err = resize_tensor(out, 
input.sizes());
+  ET_CHECK_MSG(
+      err == torch::executor::Error::Ok,
+      "Failed to resize out Tensor in quantize_per_tensor_out");
+
+  check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out);
+
+  float scale_data = (float)scale;
+  int zero_point_data = (int)zero_point;
+  quantize_impl(out,
+                input,
+                &scale_data,
+                &zero_point_data,
+                NULL,
+                (int) quant_min,
+                (int) quant_max);
+
+  return out;
+}
+
+
+Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context,
+                                            const Tensor& input,
+                                            const Tensor& scale,
+                                            const Tensor& zero_point,
+                                            int64_t quant_min,
+                                            int64_t quant_max,
+                                            ScalarType dtype,
+                                            Tensor& out)
+{
+  // Temporary change to allow not fatal failure for now to unblock some
+  // expected failure tests that are dying instead of failure. Will revisit
+  // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal
+  // failures.
+  if (scale.scalar_type() != ScalarType::Double)
+  {
+    context.fail(torch::executor::Error::InvalidArgument);
+    return out;
+  }
+  ET_CHECK_MSG(
+      scale.scalar_type() == ScalarType::Double,
+      "Expected scale to be Double tensor received: %" PRId8,
+      static_cast<int8_t>(scale.scalar_type()));
+  ET_CHECK_MSG(
+      zero_point.scalar_type() == ScalarType::Long,
+      "Expected zero_point to be Long tensor received: %" PRId8,
+      static_cast<int8_t>(zero_point.scalar_type()));
+  ET_CHECK_MSG(
+      scale.numel() == 1,
+      "Expected scale to only have one element received: %zd",
+      ssize_t(scale.numel()));
+  ET_CHECK_MSG(
+      zero_point.numel() == 1,
+      "Expected zero_point to only have one element received: %zd",
+      ssize_t(zero_point.numel()));
+
+  quantize_per_tensor_out(context,
+                          input,
+                          scale.const_data_ptr<double>()[0],
+                          zero_point.const_data_ptr<int64_t>()[0],
+                          quant_min,
+                          quant_max,
+                          dtype,
+                          out);
+
+  return out;
+}
+
+Tensor& quantize_per_tensor_tensor_args_out(const Tensor& input,
+                                            const Tensor& scale,
+                                            const Tensor& zero_point,
+                                            int64_t quant_min,
+                                            int64_t quant_max,
+                                            ScalarType dtype,
+                                            Tensor& out)
+{
+  auto context = torch::executor::RuntimeContext();
+  auto& res = quantize_per_tensor_tensor_args_out(
+      context, input, scale, zero_point, quant_min, quant_max, dtype, out);
+  ET_CHECK(context.failure_state() == Error::Ok);
+  return res;
+}
+
+Tensor& quantize_per_channel_out(const Tensor& input,
+                                 const Tensor& scale,
+                                 const Tensor& zero_point,
+                                 int64_t axis,
+                                 int64_t quant_min,
+                                 int64_t quant_max,
+                                 ScalarType dtype,
+                                 Tensor& out)
+{
+  torch::executor::Error err = resize_tensor(out, input.sizes());
+
+  // normalize axis
+  ET_CHECK_MSG(
+      executorch::runtime::tensor_has_dim(input, axis),
+      "axis %zd is not legal; it should be -input.dim() <= axis < input.dim() %zd",
+      ssize_t(axis),
+      ssize_t(input.dim()));
+
+  if (axis < 0)
+  {
+    axis += executorch::runtime::nonzero_dim(input);
+  }
+
+  ET_CHECK_MSG(
+      err == torch::executor::Error::Ok,
+      "Failed to resize out Tensor in quantize_per_channel_out");
+
+  ET_CHECK_MSG(
+      scale.scalar_type() == ScalarType::Double,
+      "scale.scalar_type() %" PRId8 " is not double type",
+      static_cast<int8_t>(scale.scalar_type()));
+
+  ET_CHECK_MSG(
+      scale.numel() == input.size(axis),
+      "scale.numel() %zd != input.size(axis) %zd",
+      scale.numel(),
+      input.size(axis));
+
+  ET_CHECK_MSG(
+      zero_point.scalar_type() == ScalarType::Long,
+      "zero_point.scalar_type() %" PRId8 " is not integer type",
+      static_cast<int8_t>(zero_point.scalar_type()));
+
+  ET_CHECK_MSG(
+      zero_point.numel() == input.size(axis),
+      "zero_point.numel() %zd != input.size(axis) %zd",
+      zero_point.numel(),
+      input.size(axis));
+
+  check_quantize_per_tensor_args(input,
quant_min, quant_max, dtype, out); + + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt = zero_point.const_data_ptr(); + + float scale_data[input.size(axis)]; + int zero_point_data[input.size(axis)]; + + for(int i = 0; i < scale.numel(); i++) + { + scale_data[i] = (float)scale_dt[i]; + zero_point_data[i] = (int)zero_point_dt[i]; + } + + int *axis_ptr = (int *)&axis; + + quantize_impl(out, + input, + scale_data, + zero_point_data, + axis_ptr, + (int) quant_min, + (int) quant_max); + + return out; +} + +Tensor& quantize_per_channel_out(KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + (void)context; + return quantize_per_channel_out( + input, scale, zero_point, axis, quant_min, quant_max, dtype, out); +} + +Tensor& quantize_per_token_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + size_t num_tokens = 1; + for (size_t i = 0; i < input.dim() - 1; i++) + { + num_tokens *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well +#ifdef USE_ATEN_LIB + std::vector sizes(2); + sizes[0] = num_tokens; + sizes[1] = input.size(input.dim() - 1); + Tensor reshaped_input = at::from_blob( + input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); +#else + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_tokens; + input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( + input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); +#endif + + return quantize_per_channel_out( + reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); +} + +Tensor& quantize_per_token_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) +{ + (void)context; + return quantize_per_token_out( + input, scale, zero_point, quant_min, quant_max, dtype, out); +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp new file mode 100644 index 0000000000..79ec6dc5d7 --- /dev/null +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include +#include +#include + +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::Error; +using torch::executor::KernelRuntimeContext; + +namespace cadence { +namespace impl { +namespace G3 { +namespace native { + +Tensor& softmax_out( + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) +{ + (void)ctx; + + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + + ET_KERNEL_CHECK( + ctx, executorch::runtime::tensors_have_same_dim_order(in, out), InvalidArgument, out); + + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + + int inp_shapes[in.dim()]; + const exec_aten::ArrayRef in_size = in.sizes(); + for(int i = 0; i < in.dim(); i++) + { + inp_shapes[i] = in_size[i]; + } + + if(out.scalar_type() == ScalarType::Float) + { + const float * const inp_data = in.const_data_ptr(); + float * const out_data = out.mutable_data_ptr(); + int axis = dim; + xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, + in.dim(), &axis); + } + else + { + ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); + + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); + }, + in_data + base, + size, + stride); + + const CTYPE temp_sum = torch::executor:: + apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; + }, + in_data + base, + out_data + base, + size, + stride); + }, + in, + dim); + }); + } + + return out; +} + +} // namespace native +} // namespace G3 +} // namespace impl +} // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt b/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt new file mode 100644 index 0000000000..a2615e0851 --- /dev/null +++ b/backends/cadence/fusion_g3/third-party/nnlib/CMakeLists.txt @@ -0,0 +1,19 @@ +cmake_minimum_required(VERSION 3.10.0) +project(cadence_nnlib) + +add_custom_target( + nnlib_target ALL + COMMAND + make install_nnlib -f makefile -C + ${EXECUTORCH_ROOT}/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3/xa_nnlib/build + OBJDIR=${CMAKE_CURRENT_BINARY_DIR}/obj + LIBDIR=${CMAKE_CURRENT_BINARY_DIR}/lib -j8 +) + +add_library(xa_nnlib STATIC IMPORTED GLOBAL) +add_dependencies(xa_nnlib nnlib_target) + +set_property( + TARGET xa_nnlib PROPERTY IMPORTED_LOCATION + "${CMAKE_CURRENT_BINARY_DIR}/lib/xa_nnlib.a" +) diff --git a/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 b/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3 new file mode 
160000
index 0000000000..8ddd1c39d4
--- /dev/null
+++ b/backends/cadence/fusion_g3/third-party/nnlib/nnlib-FusionG3
@@ -0,0 +1 @@
+Subproject commit 8ddd1c39d4b20235ebe9dac68d92848da2885ece
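
Reviewer note: for anyone who wants to sanity-check the affine mapping used by quantize_val in op_quantize.cpp, the short standalone sketch below mirrors its round-then-clamp logic outside of ExecuTorch. It is illustrative only; the name quantize_val_ref and the sample scale/zero_point values are not part of this patch.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference model of quantize_val():
//   q = clamp(zero_point + nearbyint(value / scale), quant_min, quant_max)
// with the division expressed as a multiply by 1/scale, as in the kernel.
static int64_t quantize_val_ref(double scale, int64_t zero_point, float value,
                                int64_t quant_min, int64_t quant_max) {
  const float inv_scale = 1.0f / static_cast<float>(scale);
  int64_t q = zero_point + static_cast<int64_t>(std::nearbyint(inv_scale * value));
  q = std::max(q, quant_min);
  q = std::min(q, quant_max);
  return q;
}

int main() {
  // int8 asymmetric example: scale = 0.5, zero_point = 10, value = 3.2
  // -> 3.2 / 0.5 = 6.4 -> nearbyint = 6 -> 6 + 10 = 16 (inside [-128, 127]).
  std::printf("%lld\n", static_cast<long long>(
      quantize_val_ref(0.5, 10, 3.2f, -128, 127)));
  return 0;
}

Running this prints 16. The non-NNLib per-tensor and per-channel paths in op_quantize.cpp apply this element-wise mapping via quantize_val, with scale and zero_point either shared across the whole tensor or indexed by channel_ix.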