diff --git a/backends/cadence/fusion_g3/operators/op_dequantize.cpp b/backends/cadence/fusion_g3/operators/op_dequantize.cpp index 784011332f..f450ed398f 100644 --- a/backends/cadence/fusion_g3/operators/op_dequantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_dequantize.cpp @@ -23,20 +23,16 @@ template using optional = exec_aten::optional; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is - * updated to have support for below data types, these can be removed and + * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ - - enum datatype { - Ushort = 20, - Bits4u = 21, - Bits4 = 22 - }; + +enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. */ -namespace cadence { +namespace cadence { namespace impl { namespace G3 { namespace native { @@ -46,38 +42,38 @@ namespace { /** * Asserts that the parameters are valid. */ -void check_dequantize_per_tensor_args(const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional& out_dtype, - Tensor& out) -{ - ET_CHECK_MSG( +void check_dequantize_per_tensor_args( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional& out_dtype, + Tensor& out) { + ET_CHECK_MSG( input.scalar_type() == ScalarType::Byte || input.scalar_type() == ScalarType::Char || input.scalar_type() == ScalarType::Bits16 || input.scalar_type() == ScalarType::Short || - input.scalar_type() == (ScalarType) Ushort || - input.scalar_type() == (ScalarType) Bits4 || - input.scalar_type() == (ScalarType) Bits4u || + input.scalar_type() == (ScalarType)Ushort || + input.scalar_type() == (ScalarType)Bits4 || + input.scalar_type() == (ScalarType)Bits4u || input.scalar_type() == ScalarType::Int, - + "input.scalar_type() %" PRId8 " is not supported:", static_cast(input.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( input.scalar_type() == dtype, "input.scalar_type() %" PRId8 " is not matching dtype argumenta:", static_cast(input.scalar_type())); - if (out_dtype.has_value()) { + if (out_dtype.has_value()) { ET_CHECK_MSG( out.scalar_type() == out_dtype.value(), "output_dtype must match the dtype of the out tensor"); - } + } - ET_CHECK_MSG( + ET_CHECK_MSG( quant_min <= quant_max, "quant min: %" PRId64 " is greater than quant max: %" PRId64, quant_min, @@ -86,412 +82,395 @@ void check_dequantize_per_tensor_args(const Tensor& input, } // namespace - /* Local function which calls the kernels based on the input datatype */ -void Dequantize_impl(Tensor& out, - const Tensor& input, - float *scale_data, - int *zero_point_data, - int *axis, - exec_aten::optional out_dtype) -{ - const exec_aten::ArrayRef input_size = input.sizes(); +void Dequantize_impl( + Tensor& out, + const Tensor& input, + float* scale_data, + int* zero_point_data, + int* axis, + exec_aten::optional out_dtype) { + const exec_aten::ArrayRef input_size = input.sizes(); - int kTensorDimensionLimit = 5; + int kTensorDimensionLimit = 5; - int inp_shape[kTensorDimensionLimit]; + int inp_shape[kTensorDimensionLimit]; - for(auto i = 0; i < input_size.size(); i++) - { - inp_shape[i] = input_size[i]; - } + for (auto i = 0; i < input_size.size(); i++) { + inp_shape[i] = input_size[i]; + } - bool is_asym_dequant = 0; + bool is_asym_dequant = 0; - if(zero_point_data != NULL) //asymmetric dequant + if (zero_point_data != NULL) // asymmetric dequant + { + if (axis != NULL) // channel { - if(axis != NULL) //channel - { - for(int i = 0; i < input.size(*axis) ; i++) - { - if(zero_point_data[i] != 0) - { - is_asym_dequant |= 1; - } + for (int i = 0; i < input.size(*axis); i++) { + if (zero_point_data[i] != 0) { + is_asym_dequant |= 1; } } - else + } else { + if (*zero_point_data != 0) // tesor { - if(*zero_point_data != 0) //tesor - { - is_asym_dequant |= 1; - } + is_asym_dequant |= 1; } } - float* out_data = out.mutable_data_ptr(); - - if(is_asym_dequant) - { - if (input.scalar_type() == ScalarType::Byte) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == ScalarType::Char) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym8_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); + } + float* out_data = out.mutable_data_ptr(); + + if (is_asym_dequant) { + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym8_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Ushort) { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym16_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4u) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4u_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_asym4_f32( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + zero_point_data, + scale_data); + } else { + if (axis == NULL) { +// calculate the dequantized output, cast scale to float to match fbgemm +// behavior +#define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; +#define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); + ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (input.scalar_type() == (ScalarType) Ushort) - { - const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == ScalarType::Short) - { - const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym16_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4u) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4u_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_asym4_f32( - out_data, input_data, inp_shape, input.dim(), axis, - zero_point_data, scale_data); +#undef ASYM_CALCULATE_INT_TYPE_TENSOR +#undef ASYM_DEQUANTIZE_IMPL_TESNOR + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the dequantized output, cast scale to float to match fbgemm - // behavior - #define ASYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - out_data_ptr[i] = static_cast( \ - (input_data_ptr[i] - static_cast(*zero_point_data)) * \ - static_cast(*scale_data)); \ - } \ - } break; - #define ASYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_TESNOR); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_TENSOR); - ASYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef ASYM_CALCULATE_INT_TYPE_TENSOR - #undef ASYM_DEQUANTIZE_IMPL_TESNOR - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual dequantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are dequantizing. - // in other words you are dequantizing in_data[in_ix] - #define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - if (input.dim() == 1) { \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - ET_CHECK_MSG( \ - *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ - torch::executor::apply_over_dim( \ - [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ - size_t numel, size_t stride, size_t base_ix) { \ - for (size_t i = 0; i < numel; i++) { \ - size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ - int64_t zero_point = 0; \ - if (zero_point_data != nullptr) { \ - zero_point = zero_point_data[current_ix]; \ - } \ - out_data_ptr[current_ix] = \ - static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ - _scale; \ - } \ - }, \ - input, \ - dim); \ - break; \ - } \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ - int64_t _zero_point = 0; \ - if (zero_point_data != nullptr) { \ - _zero_point = zero_point_data[channel_ix]; \ - } \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ - out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); - ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef ASYM_CALCULATE_INT_TYPE_CHANNEL - #undef ASYM_DEQUANTIZE_IMPL_CHANNEL - } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual dequantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are dequantizing. +// in other words you are dequantizing in_data[in_ix] +#define ASYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define ASYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, ASYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(ASYM_CALCULATE_INT_TYPE_CHANNEL); + ASYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } +#undef ASYM_CALCULATE_INT_TYPE_CHANNEL +#undef ASYM_DEQUANTIZE_IMPL_CHANNEL + } } - else - { - if (input.scalar_type() == ScalarType::Byte) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == ScalarType::Char) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym8_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == (ScalarType) Ushort) - { - const uint16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == ScalarType::Short) - { - const int16_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym16_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); - } - else if (input.scalar_type() == (ScalarType) Bits4u) - { - const uint8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4u_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else { + if (input.scalar_type() == ScalarType::Byte) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == ScalarType::Char) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym8_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Ushort) { + const uint16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == ScalarType::Short) { + const int16_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym16_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4u) { + const uint8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4u_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else if (input.scalar_type() == (ScalarType)Bits4) { + const int8_t* input_data = input.const_data_ptr(); + xa_nn_elm_dequantize_sym4_f32( + out_data, input_data, inp_shape, input.dim(), axis, scale_data); + } else { + if (axis == NULL) { +// calculate the dequantized output, cast scale to float to match fbgemm +// behavior +#define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + out_data_ptr[i] = static_cast( \ + (input_data_ptr[i] - static_cast(*zero_point_data)) * \ + static_cast(*scale_data)); \ + } \ + } break; +#define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); + SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (input.scalar_type() == (ScalarType) Bits4) - { - const int8_t* input_data = input.const_data_ptr(); - xa_nn_elm_dequantize_sym4_f32( - out_data, input_data, inp_shape, input.dim(), axis, scale_data); +#undef SYM_DEQUANTIZE_IMPL_TESNOR +#undef SYM_CALCULATE_INT_TYPE_TENSOR + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the dequantized output, cast scale to float to match fbgemm - // behavior - #define SYM_DEQUANTIZE_IMPL_TESNOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - out_data_ptr[i] = static_cast( \ - (input_data_ptr[i] - static_cast(*zero_point_data)) * \ - static_cast(*scale_data)); \ - } \ - } break; - #define SYM_CALCULATE_INT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_TESNOR); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_TENSOR); - SYM_CALCULATE_INT_TYPE_TENSOR(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef SYM_DEQUANTIZE_IMPL_TESNOR - #undef SYM_CALCULATE_INT_TYPE_TENSOR - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual dequantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are dequantizing. - // in other words you are dequantizing in_data[in_ix] - #define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - if (input.dim() == 1) { \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - ET_CHECK_MSG( \ - *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ - const optional dim; \ - torch::executor::apply_over_dim( \ - [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ - size_t numel, size_t stride, size_t base_ix) { \ - for (size_t i = 0; i < numel; i++) { \ - size_t current_ix = base_ix * stride + i; \ - float _scale = scale_data[current_ix]; \ - int64_t zero_point = 0; \ - if (zero_point_data != nullptr) { \ - zero_point = zero_point_data[current_ix]; \ - } \ - out_data_ptr[current_ix] = \ - static_cast( \ - input_data_ptr[current_ix] - zero_point) * \ - _scale; \ - } \ - }, \ - input, \ - dim); \ - break; \ - } \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - float _scale = scale_data[channel_ix]; \ - int64_t _zero_point = 0; \ - if (zero_point_data != nullptr) { \ - _zero_point = zero_point_data[channel_ix]; \ - } \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ - out_data_ptr[in_ix] = static_cast( \ - (input_data_ptr[in_ix] - _zero_point) * _scale); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - switch (input.scalar_type()) { - ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); - SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - #undef SYM_DEQUANTIZE_IMPL_CHANNEL - #undef SYM_CALCULATE_INT_TYPE_CHANNEL - } + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual dequantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are dequantizing. +// in other words you are dequantizing in_data[in_ix] +#define SYM_DEQUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + if (input.dim() == 1) { \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + ET_CHECK_MSG( \ + *axis == 0, "Axis must be 0 for a single dimensional tensors"); \ + const optional dim; \ + torch::executor::apply_over_dim( \ + [input_data_ptr, out_data_ptr, zero_point_data, scale_data]( \ + size_t numel, size_t stride, size_t base_ix) { \ + for (size_t i = 0; i < numel; i++) { \ + size_t current_ix = base_ix * stride + i; \ + float _scale = scale_data[current_ix]; \ + int64_t zero_point = 0; \ + if (zero_point_data != nullptr) { \ + zero_point = zero_point_data[current_ix]; \ + } \ + out_data_ptr[current_ix] = \ + static_cast( \ + input_data_ptr[current_ix] - zero_point) * \ + _scale; \ + } \ + }, \ + input, \ + dim); \ + break; \ + } \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + float _scale = scale_data[channel_ix]; \ + int64_t _zero_point = 0; \ + if (zero_point_data != nullptr) { \ + _zero_point = zero_point_data[channel_ix]; \ + } \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, out_data_ptr, _scale, _zero_point](size_t in_ix) { \ + out_data_ptr[in_ix] = static_cast( \ + (input_data_ptr[in_ix] - _zero_point) * _scale); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define SYM_CALCULATE_INT_TYPE_CHANNEL(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_FLOAT_TYPES_WITH(IN_CTYPE, SYM_DEQUANTIZE_IMPL_CHANNEL); \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + switch (input.scalar_type()) { + ET_FORALL_INT_TYPES(SYM_CALCULATE_INT_TYPE_CHANNEL); + SYM_CALCULATE_INT_TYPE_CHANNEL(uint16_t, Bits16); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } +#undef SYM_DEQUANTIZE_IMPL_CHANNEL +#undef SYM_CALCULATE_INT_TYPE_CHANNEL + } } + } } /** @@ -511,56 +490,50 @@ Tensor& dequantize_per_tensor_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_tensor_out"); - check_dequantize_per_tensor_args( + check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); - - float scale_data = (float)scale; - int zero_point_data = (int)zero_point; - - Dequantize_impl(out, - input, - &scale_data, - &zero_point_data, - NULL, - out_dtype); - - return out; + + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + + Dequantize_impl(out, input, &scale_data, &zero_point_data, NULL, out_dtype); + + return out; } -Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) -{ - ET_CHECK_MSG( +Tensor& dequantize_per_tensor_tensor_args_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) { + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, "Expected scale to be Long tensor received: %" PRId8, static_cast(zero_point.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == 1, "Exepcted scale to only have one element received: %zd", ssize_t(scale.numel())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); - dequantize_per_tensor_out( + dequantize_per_tensor_out( input, scale.const_data_ptr()[0], zero_point.const_data_ptr()[0], @@ -570,49 +543,48 @@ Tensor& dequantize_per_tensor_tensor_args_out(const Tensor& input, out_dtype, out); - return out; + return out; } -Tensor& dequantize_per_channel_out(const Tensor& input, - const Tensor& scale, - const exec_aten::optional& opt_zero_points, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - exec_aten::optional out_dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - - // normalize axis - ET_CHECK_MSG( +Tensor& dequantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const exec_aten::optional& opt_zero_points, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + exec_aten::optional out_dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + + // normalize axis + ET_CHECK_MSG( executorch::runtime::tensor_has_dim(input, axis), "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", ssize_t(axis), ssize_t(input.dim())); - if (axis < 0) - { - axis += executorch::runtime::nonzero_dim(input); - } + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } - ET_CHECK_MSG( + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "scale.scalar_type() %" PRId8 " is not double type", static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == input.size(axis), "scale.numel() %zd != input.size(axis) %zd", ssize_t(scale.numel()), ssize_t(input.size(axis))); - if (opt_zero_points.has_value()) { + if (opt_zero_points.has_value()) { auto zero_point = opt_zero_points.value(); ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, @@ -624,41 +596,31 @@ Tensor& dequantize_per_channel_out(const Tensor& input, "zero_point.numel() %zd != input.size(axis) %zd", ssize_t(zero_point.numel()), ssize_t(input.size(axis))); - } + } - check_dequantize_per_tensor_args( + check_dequantize_per_tensor_args( input, quant_min, quant_max, dtype, out_dtype, out); - - int *axis_ptr = (int *)&axis; - - const double* scale_dt = scale.const_data_ptr(); - const int64_t* zero_point_dt; - int zero_point_data[input.size(axis)]; - int *zero_point_ptr; - if (opt_zero_points.has_value()) - { - zero_point_dt = opt_zero_points.value().const_data_ptr(); - zero_point_ptr = &zero_point_data[0]; - for(int i = 0; i < scale.numel(); i++) - { - zero_point_ptr[i] = (int)zero_point_dt[i]; - } - } - else - { - zero_point_ptr = nullptr; - } - float scale_data[input.size(axis)]; - for(int i = 0; i < scale.numel(); i++) - { - scale_data[i] = (float)scale_dt[i]; + + int* axis_ptr = (int*)&axis; + + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt; + int zero_point_data[input.size(axis)]; + int* zero_point_ptr; + if (opt_zero_points.has_value()) { + zero_point_dt = opt_zero_points.value().const_data_ptr(); + zero_point_ptr = &zero_point_data[0]; + for (int i = 0; i < scale.numel(); i++) { + zero_point_ptr[i] = (int)zero_point_dt[i]; } - Dequantize_impl(out, - input, - scale_data, - zero_point_ptr, - axis_ptr, - out_dtype); + } else { + zero_point_ptr = nullptr; + } + float scale_data[input.size(axis)]; + for (int i = 0; i < scale.numel(); i++) { + scale_data[i] = (float)scale_dt[i]; + } + Dequantize_impl(out, input, scale_data, zero_point_ptr, axis_ptr, out_dtype); return out; } @@ -673,14 +635,13 @@ Tensor& dequantize_per_channel_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - (void)context; - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + Tensor& out) { + (void)context; + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); - return dequantize_per_channel_out( + return dequantize_per_channel_out( input, scale, opt_zero_points, @@ -701,12 +662,11 @@ Tensor& dequantize_per_tensor_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_out( + Tensor& out) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_out( input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); } @@ -719,12 +679,11 @@ Tensor& dequantize_per_tensor_tensor_args_out( int64_t quant_max, ScalarType dtype, exec_aten::optional out_dtype, - Tensor& out) -{ - // TODO(larryliu): Add a context arg to the real op function and remove this - // wrapper - (void)context; - return dequantize_per_tensor_tensor_args_out( + Tensor& out) { + // TODO(larryliu): Add a context arg to the real op function and remove this + // wrapper + (void)context; + return dequantize_per_tensor_tensor_args_out( input, scale, zero_point, quant_min, quant_max, dtype, out_dtype, out); } @@ -736,47 +695,46 @@ Tensor& dequantize_per_token_out( int64_t quant_max, ScalarType dtype, ScalarType out_dtype, - Tensor& out) -{ - // Refactor this into a util - size_t num_channels = 1; - for (size_t i = 0; i < input.dim() - 1; i++) - { - num_channels *= input.size(i); - } - // This unfortunate change is needed because we compile op_quantize for aten - // mode as well - std::array input_sizes; - input_sizes[0] = static_cast(num_channels); - input_sizes[1] = + Tensor& out) { + // Refactor this into a util + size_t num_channels = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + num_channels *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well + std::array input_sizes; + input_sizes[0] = static_cast(num_channels); + input_sizes[1] = static_cast(input.size(input.dim() - 1)); #ifdef USE_ATEN_LIB - Tensor reshaped_input = at::from_blob( + Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), input_sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_strides; - executorch::runtime::dim_order_to_stride_nocheck( + std::array input_dim_order{0, 1}; + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); - void* input_data = input.mutable_data_ptr(); - torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( - input.scalar_type(), - 2, - input_sizes.data(), - input_data, - input_dim_order.data(), - input_strides.data(), - executorch::runtime::TensorShapeDynamism::STATIC); - Tensor reshaped_input(&reshaped_input_impl); - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = + executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in dequantize_per_channel_out"); #endif - return dequantize_per_channel_out( + return dequantize_per_channel_out( reshaped_input, scale, zero_points, @@ -797,8 +755,7 @@ Tensor& dequantize_per_token_out( int64_t quant_max, ScalarType dtype, ScalarType out_dtype, - Tensor& out) -{ + Tensor& out) { (void)context; return dequantize_per_token_out( input, scale, zero_points, quant_min, quant_max, dtype, out_dtype, out); diff --git a/backends/cadence/fusion_g3/operators/op_quantize.cpp b/backends/cadence/fusion_g3/operators/op_quantize.cpp index bc84829edb..2b8376dc8d 100644 --- a/backends/cadence/fusion_g3/operators/op_quantize.cpp +++ b/backends/cadence/fusion_g3/operators/op_quantize.cpp @@ -8,10 +8,10 @@ #include #include +#include #include #include #include -#include using exec_aten::Scalar; using exec_aten::ScalarType; @@ -21,14 +21,10 @@ using torch::executor::KernelRuntimeContext; /* ScalarType in Executorch do not have support for below data types. * So, creating a placeholder for these data types. Once, ScalarTypes is - * updated to have support for below data types, these can be removed and + * updated to have support for below data types, these can be removed and * operator need to be updated accordingly */ - enum datatype { - Ushort = 20, - Bits4u = 21, - Bits4 = 22 - }; +enum datatype { Ushort = 20, Bits4u = 21, Bits4 = 22 }; /** * For an input tensor, use the scale and zero_point arguments to quantize it. @@ -38,102 +34,84 @@ namespace impl { namespace FusionG3 { namespace native { - namespace { /** * Asserts that the parameters are valid. */ -void check_quantize_per_tensor_args(const Tensor& input, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - // Ensure self and out has the same shape - ET_CHECK_MSG( +void check_quantize_per_tensor_args( + const Tensor& input, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + // Ensure self and out has the same shape + ET_CHECK_MSG( torch::executor::isFloatingType(input.scalar_type()), "input.scalar_type() %" PRId8 " is not floating type", static_cast(input.scalar_type())); - int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; - ScalarType out_dtype = out.scalar_type(); - ET_CHECK_MSG( + int32_t quant_min_lower_bound = 0, quant_max_upper_bound = 0; + ScalarType out_dtype = out.scalar_type(); + ET_CHECK_MSG( out_dtype == dtype, "out.scalar_type() %" PRId8 " is not matching dtype argument %" PRId8, static_cast(out_dtype), static_cast(dtype)); - if (out_dtype == ScalarType::Byte) - { - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - } - else if (dtype == ScalarType::Char) - { - quant_min_lower_bound = - static_cast(std::numeric_limits::min()); - quant_max_upper_bound = - static_cast(std::numeric_limits::max()); - } - else if (dtype == ScalarType::Bits16) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == ScalarType::Short) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == (ScalarType)Ushort) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else if (dtype == (ScalarType)Bits4u) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - /* Minimum and maximum values fo unsigned 4-bit data type */ - quant_min_lower_bound = quant_min_lower_bound >> 4; - quant_max_upper_bound = quant_max_upper_bound >> 4; - } - else if (dtype == (ScalarType)Bits4) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - /* Minimum and maximum values fo signed 4-bit data type */ - quant_min_lower_bound = quant_min_lower_bound >> 4; - quant_max_upper_bound = quant_max_upper_bound >> 4; - } - else if (dtype == ScalarType::Int) - { - quant_min_lower_bound = std::numeric_limits::min(); - quant_max_upper_bound = std::numeric_limits::max(); - } - else - { - ET_CHECK_MSG( - false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); - } - + if (out_dtype == ScalarType::Byte) { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } else if (dtype == ScalarType::Char) { + quant_min_lower_bound = + static_cast(std::numeric_limits::min()); + quant_max_upper_bound = + static_cast(std::numeric_limits::max()); + } else if (dtype == ScalarType::Bits16) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == ScalarType::Short) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == (ScalarType)Ushort) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else if (dtype == (ScalarType)Bits4u) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo unsigned 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } else if (dtype == (ScalarType)Bits4) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + /* Minimum and maximum values fo signed 4-bit data type */ + quant_min_lower_bound = quant_min_lower_bound >> 4; + quant_max_upper_bound = quant_max_upper_bound >> 4; + } else if (dtype == ScalarType::Int) { + quant_min_lower_bound = std::numeric_limits::min(); + quant_max_upper_bound = std::numeric_limits::max(); + } else { ET_CHECK_MSG( + false, "Unsupported dtype: %" PRId8, static_cast(out_dtype)); + } + + ET_CHECK_MSG( quant_min >= quant_min_lower_bound, "quant_min out of bound for dtype, expected quant_min_lower_bound: %" PRId32 " actual quant_min: %" PRId64, quant_min_lower_bound, quant_min); - ET_CHECK_MSG( + ET_CHECK_MSG( quant_max <= quant_max_upper_bound, "quant_max out of bound for dtype, expected quant_max_upper_bound: %" PRId32 " actual quant_max: %" PRId64, quant_max_upper_bound, quant_max); -}/* check_quantize_per_tensor_args */ +} /* check_quantize_per_tensor_args */ } // namespace @@ -143,8 +121,7 @@ T quantize_val( int64_t zero_point, K value, int64_t quant_min, - int64_t quant_max) -{ + int64_t quant_max) { int64_t qvalue; float inv_scale = 1.0f / static_cast(scale); qvalue = static_cast( @@ -156,458 +133,495 @@ T quantize_val( return static_cast(qvalue); } - /* Local function which calls the kernels based on the output datatype */ -void quantize_impl(Tensor& out, - const Tensor& input, - float *scale_data, - int *zero_point_data, - int *axis, - int quant_min, - int quant_max) -{ - const exec_aten::ArrayRef input_size = input.sizes(); +void quantize_impl( + Tensor& out, + const Tensor& input, + float* scale_data, + int* zero_point_data, + int* axis, + int quant_min, + int quant_max) { + const exec_aten::ArrayRef input_size = input.sizes(); - int kTensorDimensionLimit = 5; + int kTensorDimensionLimit = 5; - int inp_shape[kTensorDimensionLimit]; + int inp_shape[kTensorDimensionLimit]; - for(auto i = 0; i < input_size.size(); i++) - { - inp_shape[i] = input_size[i]; - } - - const float* input_data = input.const_data_ptr(); + for (auto i = 0; i < input_size.size(); i++) { + inp_shape[i] = input_size[i]; + } - bool is_asym_quant = 0; + const float* input_data = input.const_data_ptr(); - if(zero_point_data != NULL) //asymmetric quant + bool is_asym_quant = 0; + + if (zero_point_data != NULL) // asymmetric quant + { + if (axis != NULL) // channel { - if(axis != NULL) //channel - { - for(int i = 0; i < input.size(*axis) ; i++) - { - if(zero_point_data[i] != 0) - { - is_asym_quant |= 1; - } + for (int i = 0; i < input.size(*axis); i++) { + if (zero_point_data[i] != 0) { + is_asym_quant |= 1; } } - else + } else { + if (*zero_point_data != 0) // tensor { - if(*zero_point_data != 0) //tensor - { - is_asym_quant |= 1; - } + is_asym_quant |= 1; } } - - if(is_asym_quant) - { - if (out.scalar_type() == ScalarType::Byte) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); + } + + if (is_asym_quant) { + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym8( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Ushort) { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym16( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4u) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_asym4( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + zero_point_data, + quant_min, + quant_max); + } else { + if (axis == NULL) { + // Vector quantization +// calculate the quantized input +#define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, \ + (int64_t) * zero_point_data, \ + value, \ + (int64_t)quant_min, \ + (int64_t)quant_max); \ + } \ + } break; +#define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ + ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (out.scalar_type() == ScalarType::Char) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym8( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Ushort) - { - uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == ScalarType::Short) - { - int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym16( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Bits4u) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym4u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType)Bits4) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_asym4( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, zero_point_data, quant_min, quant_max); + + } else { + // Channel based quantization + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // Vector quantization - // calculate the quantized input - #define ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - IN_CTYPE value = input_data_ptr[i]; \ - out_data_ptr[i] = quantize_val( \ - (double)*scale_data, (int64_t)*zero_point_data, value, \ - (int64_t)quant_min, (int64_t)quant_max); \ - } \ - } break; - #define ASYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(IN_CTYPE, ASYM_QUANTIZE_IMPL_TENSOR); \ - ASYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_TENSOR); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - - } - else - { - // Channel based quantization - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] - #define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - double _scale = (double)scale_data[channel_ix]; \ - int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ - ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - } - - #undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR - #undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL - #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR - #undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual quantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are quantizing. +// in other words you are quantizing in_data[in_ix] +#define ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define ASYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, ASYM_QUANTIZE_IMPL_CHANNEL); \ + ASYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(ASYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } + } + +#undef ASYM_CALCULATE_FLOAT_TYPE_TENSOR +#undef ASYM_CALCULATE_FLOAT_TYPE_CHANNEL +#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR +#undef ASYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL } - else - { - if (out.scalar_type() == ScalarType::Byte) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == ScalarType::Char) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym8( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Ushort) - { - uint16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); + } else { + if (out.scalar_type() == ScalarType::Byte) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Char) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym8( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Ushort) { + uint16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == ScalarType::Short) { + int16_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym16( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4u) { + uint8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4u( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else if (out.scalar_type() == (ScalarType)Bits4) { + int8_t* out_data = out.mutable_data_ptr(); + xa_nn_elm_quantize_f32_sym4( + out_data, + input_data, + inp_shape, + input.dim(), + axis, + scale_data, + quant_min, + quant_max); + } else { + if (axis == NULL) { + // calculate the quantized input +#define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ + case ScalarType::out_dtype: { \ + /* Hoist these function calls out of our inner loop because they might not \ + * get inlined without LTO, particularly in ATen mode. */ \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + const auto input_numel = input.numel(); \ + for (size_t i = 0; i < input_numel; i++) { \ + IN_CTYPE value = input_data_ptr[i]; \ + out_data_ptr[i] = quantize_val( \ + (double)*scale_data, \ + (int64_t) * zero_point_data, \ + value, \ + (int64_t)quant_min, \ + (int64_t)quant_max); \ + } \ + } break; +#define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ + SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } - else if (out.scalar_type() == ScalarType::Short) - { - int16_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym16( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Bits4u) - { - uint8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4u( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); - } - else if (out.scalar_type() == (ScalarType) Bits4) - { - int8_t* out_data = out.mutable_data_ptr(); - xa_nn_elm_quantize_f32_sym4( - out_data, input_data, inp_shape, input.dim(), axis, - scale_data, quant_min, quant_max); + + } else { + // a list contains all dimensions except axis + int64_t dims[input.dim() - 1]; + for (int64_t i = 0; i < input.dim() - 1; i++) { + if (i < *axis) { + dims[i] = i; + } else { + dims[i] = i + 1; + } } - else - { - if(axis == NULL) - { - // calculate the quantized input - #define SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, OUT_CTYPE, out_dtype) \ - case ScalarType::out_dtype: { \ - /* Hoist these function calls out of our inner loop because they might not \ - * get inlined without LTO, particularly in ATen mode. */ \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - const auto input_numel = input.numel(); \ - for (size_t i = 0; i < input_numel; i++) { \ - IN_CTYPE value = input_data_ptr[i]; \ - out_data_ptr[i] = quantize_val( \ - (double)*scale_data, (int64_t)*zero_point_data, value, \ - (int64_t)quant_min, (int64_t)quant_max); \ - } \ - } break; - #define SYM_CALCULATE_FLOAT_TYPE_TENSOR(IN_CTYPE, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(IN_CTYPE, SYM_QUANTIZE_IMPL_TENSOR); \ - SYM_QUANTIZE_IMPL_TENSOR(IN_CTYPE, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_TENSOR); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - - } - else - { - // a list contains all dimensions except axis - int64_t dims[input.dim() - 1]; - for (int64_t i = 0; i < input.dim() - 1; i++) - { - if (i < *axis) - { - dims[i] = i; - } - else - { - dims[i] = i + 1; - } - } - - exec_aten::optional> optional_dim_list{ - exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; - - // Actual quantization logic - // input, out are the input and output tensors - // channel_ix is the index along the axis dimension. 0 <= channel_ix < - // input.size(axis). - // i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix - // will be 0, 1, 2, ... C-1 - // in_ix is the flat index of the element you are quantizing. - // in other words you are quantizing in_data[in_ix] - #define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ - case ScalarType::out_dtype: \ - for (size_t channel_ix = 0; channel_ix < input.size(*axis); ++channel_ix) { \ - double _scale = (double)scale_data[channel_ix]; \ - int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ - auto* out_data_ptr = out.mutable_data_ptr(); \ - const auto* input_data_ptr = input.const_data_ptr(); \ - torch::executor::apply_over_dim_list( \ - [input_data_ptr, \ - out_data_ptr, \ - _scale, \ - _zero_point, \ - quant_min, \ - quant_max](size_t in_ix) { \ - out_data_ptr[in_ix] = quantize_val( \ - _scale, \ - _zero_point, \ - input_data_ptr[in_ix], \ - quant_min, \ - quant_max); \ - }, \ - input, \ - optional_dim_list, \ - channel_ix); \ - } \ - break; - #define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ - case ScalarType::in_dtype: \ - switch (out.scalar_type()) { \ - ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ - SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ - default: \ - ET_CHECK_MSG( \ - false, \ - "Unhandled output dtype %" PRId8, \ - static_cast(out.scalar_type())); \ - } \ - break; - - switch (input.scalar_type()) { - ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); - default: - ET_CHECK_MSG( - false, - "Unhandled input dtype %" PRId8, - static_cast(input.scalar_type())); - } - } - #undef SYM_CALCULATE_FLOAT_TYPE_TENSOR - #undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL - #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR - #undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL + + exec_aten::optional> optional_dim_list{ + exec_aten::ArrayRef{dims, size_t(input.dim() - 1)}}; + +// Actual quantization logic +// input, out are the input and output tensors +// channel_ix is the index along the axis dimension. 0 <= channel_ix < +// input.size(axis). +// i.e. if the tensor has shape (N,C,H,W), axis being 1, then channel_ix +// will be 0, 1, 2, ... C-1 +// in_ix is the flat index of the element you are quantizing. +// in other words you are quantizing in_data[in_ix] +#define SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, CTYPE_OUT, out_dtype) \ + case ScalarType::out_dtype: \ + for (size_t channel_ix = 0; channel_ix < input.size(*axis); \ + ++channel_ix) { \ + double _scale = (double)scale_data[channel_ix]; \ + int64_t _zero_point = (int64_t)zero_point_data[channel_ix]; \ + auto* out_data_ptr = out.mutable_data_ptr(); \ + const auto* input_data_ptr = input.const_data_ptr(); \ + torch::executor::apply_over_dim_list( \ + [input_data_ptr, \ + out_data_ptr, \ + _scale, \ + _zero_point, \ + quant_min, \ + quant_max](size_t in_ix) { \ + out_data_ptr[in_ix] = quantize_val( \ + _scale, \ + _zero_point, \ + input_data_ptr[in_ix], \ + quant_min, \ + quant_max); \ + }, \ + input, \ + optional_dim_list, \ + channel_ix); \ + } \ + break; +#define SYM_CALCULATE_FLOAT_TYPE_CHANNEL(CTYPE_IN, in_dtype) \ + case ScalarType::in_dtype: \ + switch (out.scalar_type()) { \ + ET_FORALL_INT_TYPES_WITH(CTYPE_IN, SYM_QUANTIZE_IMPL_CHANNEL); \ + SYM_QUANTIZE_IMPL_CHANNEL(CTYPE_IN, uint16_t, Bits16) \ + default: \ + ET_CHECK_MSG( \ + false, \ + "Unhandled output dtype %" PRId8, \ + static_cast(out.scalar_type())); \ + } \ + break; + + switch (input.scalar_type()) { + ET_FORALL_FLOAT_TYPES(SYM_CALCULATE_FLOAT_TYPE_CHANNEL); + default: + ET_CHECK_MSG( + false, + "Unhandled input dtype %" PRId8, + static_cast(input.scalar_type())); } + } +#undef SYM_CALCULATE_FLOAT_TYPE_TENSOR +#undef SYM_CALCULATE_FLOAT_TYPE_CHANNEL +#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_TENSOR +#undef SYM_ASYM_QUANTIZE_IMPL_CHANNEL_CHANNEL } + } } // Quantize the input tensor -Tensor& quantize_per_tensor_out(KernelRuntimeContext& context, - const Tensor& input, - double scale, - int64_t zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( +Tensor& quantize_per_tensor_out( + KernelRuntimeContext& context, + const Tensor& input, + double scale, + int64_t zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in quantize_per_tensor_out"); - check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - float scale_data = (float)scale; - int zero_point_data = (int)zero_point; - quantize_impl(out, - input, - &scale_data, - &zero_point_data, - NULL, - (int) quant_min, - (int) quant_max); + float scale_data = (float)scale; + int zero_point_data = (int)zero_point; + quantize_impl( + out, + input, + &scale_data, + &zero_point_data, + NULL, + (int)quant_min, + (int)quant_max); - return out; + return out; } - -Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - // Temporary change to allow not fatal failure for now to unblock some - // expected failure tests that are dying instead of failure. Will revisit - // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal - // failures. - if (scale.scalar_type() != ScalarType::Double) - { - context.fail(torch::executor::Error::InvalidArgument); - return out; - } - ET_CHECK_MSG( +Tensor& quantize_per_tensor_tensor_args_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + // Temporary change to allow not fatal failure for now to unblock some + // expected failure tests that are dying instead of failure. Will revisit + // after ET_KERNEL_CHECK is fully implemented and properly allows non fatal + // failures. + if (scale.scalar_type() != ScalarType::Double) { + context.fail(torch::executor::Error::InvalidArgument); + return out; + } + ET_CHECK_MSG( scale.scalar_type() == ScalarType::Double, "Expected scale to be Double tensor received: %" PRId8, static_cast(scale.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.scalar_type() == ScalarType::Long, "Expected zero_point to be Long tensor received: %" PRId8, static_cast(zero_point.scalar_type())); - ET_CHECK_MSG( + ET_CHECK_MSG( scale.numel() == 1, "Exepcted scale to only have one element received: %zd", ssize_t(scale.numel())); - ET_CHECK_MSG( + ET_CHECK_MSG( zero_point.numel() == 1, "Exepcted zero_point to only have one element received: %zd", ssize_t(zero_point.numel())); - quantize_per_tensor_out(context, + quantize_per_tensor_out( + context, input, scale.const_data_ptr()[0], zero_point.const_data_ptr()[0], @@ -616,113 +630,111 @@ Tensor& quantize_per_tensor_tensor_args_out(KernelRuntimeContext& context, dtype, out); - return out; + return out; } -Tensor& quantize_per_tensor_tensor_args_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - auto context = torch::executor::RuntimeContext(); - auto& res = quantize_per_tensor_tensor_args_out( +Tensor& quantize_per_tensor_tensor_args_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + auto context = torch::executor::RuntimeContext(); + auto& res = quantize_per_tensor_tensor_args_out( context, input, scale, zero_point, quant_min, quant_max, dtype, out); - ET_CHECK(context.failure_state() == Error::Ok); - return res; + ET_CHECK(context.failure_state() == Error::Ok); + return res; } -Tensor& quantize_per_channel_out(const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - torch::executor::Error err = resize_tensor(out, input.sizes()); - - // normalize axis - ET_CHECK_MSG( - executorch::runtime::tensor_has_dim(input, axis), - "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", - ssize_t(axis), - ssize_t(input.dim())); +Tensor& quantize_per_channel_out( + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + torch::executor::Error err = resize_tensor(out, input.sizes()); - if (axis < 0) - { - axis += executorch::runtime::nonzero_dim(input); - } + // normalize axis + ET_CHECK_MSG( + executorch::runtime::tensor_has_dim(input, axis), + "axis %zd is not legal it should be -input.dim() <= axis < input.dim() %zd", + ssize_t(axis), + ssize_t(input.dim())); - ET_CHECK_MSG( - err == torch::executor::Error::Ok, - "Failed to resize out Tensor in quantize_per_channel_out"); + if (axis < 0) { + axis += executorch::runtime::nonzero_dim(input); + } - ET_CHECK_MSG( - scale.scalar_type() == ScalarType::Double, - "scale.scalar_type() %" PRId8 " is not double type", - static_cast(scale.scalar_type())); + ET_CHECK_MSG( + err == torch::executor::Error::Ok, + "Failed to resize out Tensor in quantize_per_channel_out"); - ET_CHECK_MSG( - scale.numel() == input.size(axis), - "scale.numel() %zd != input.size(axis) %zd", - scale.numel(), - input.size(axis)); + ET_CHECK_MSG( + scale.scalar_type() == ScalarType::Double, + "scale.scalar_type() %" PRId8 " is not double type", + static_cast(scale.scalar_type())); - ET_CHECK_MSG( - zero_point.scalar_type() == ScalarType::Long, - "zero_point.scalar_type() %" PRId8 " is not integer type", - static_cast(zero_point.scalar_type())); + ET_CHECK_MSG( + scale.numel() == input.size(axis), + "scale.numel() %zd != input.size(axis) %zd", + scale.numel(), + input.size(axis)); - ET_CHECK_MSG( - zero_point.numel() == input.size(axis), - "zero_point.numel() %zd != input.size(axis) %zd", - zero_point.numel(), - input.size(axis)); + ET_CHECK_MSG( + zero_point.scalar_type() == ScalarType::Long, + "zero_point.scalar_type() %" PRId8 " is not integer type", + static_cast(zero_point.scalar_type())); - check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); + ET_CHECK_MSG( + zero_point.numel() == input.size(axis), + "zero_point.numel() %zd != input.size(axis) %zd", + zero_point.numel(), + input.size(axis)); + check_quantize_per_tensor_args(input, quant_min, quant_max, dtype, out); - const double* scale_dt = scale.const_data_ptr(); - const int64_t* zero_point_dt = zero_point.const_data_ptr(); - - float scale_data[input.size(axis)]; - int zero_point_data[input.size(axis)]; + const double* scale_dt = scale.const_data_ptr(); + const int64_t* zero_point_dt = zero_point.const_data_ptr(); - for(int i = 0; i < scale.numel(); i++) - { - scale_data[i] = (float)scale_dt[i]; - zero_point_data[i] = (int)zero_point_dt[i]; - } + float scale_data[input.size(axis)]; + int zero_point_data[input.size(axis)]; - int *axis_ptr = (int *)&axis; + for (int i = 0; i < scale.numel(); i++) { + scale_data[i] = (float)scale_dt[i]; + zero_point_data[i] = (int)zero_point_dt[i]; + } - quantize_impl(out, - input, - scale_data, - zero_point_data, - axis_ptr, - (int) quant_min, - (int) quant_max); + int* axis_ptr = (int*)&axis; - return out; + quantize_impl( + out, + input, + scale_data, + zero_point_data, + axis_ptr, + (int)quant_min, + (int)quant_max); + + return out; } -Tensor& quantize_per_channel_out(KernelRuntimeContext& context, - const Tensor& input, - const Tensor& scale, - const Tensor& zero_point, - int64_t axis, - int64_t quant_min, - int64_t quant_max, - ScalarType dtype, - Tensor& out) -{ - (void)context; - return quantize_per_channel_out( +Tensor& quantize_per_channel_out( + KernelRuntimeContext& context, + const Tensor& input, + const Tensor& scale, + const Tensor& zero_point, + int64_t axis, + int64_t quant_min, + int64_t quant_max, + ScalarType dtype, + Tensor& out) { + (void)context; + return quantize_per_channel_out( input, scale, zero_point, axis, quant_min, quant_max, dtype, out); } @@ -733,46 +745,45 @@ Tensor& quantize_per_token_out( int64_t quant_min, int64_t quant_max, ScalarType dtype, - Tensor& out) -{ - size_t num_tokens = 1; - for (size_t i = 0; i < input.dim() - 1; i++) - { - num_tokens *= input.size(i); - } - // This unfortunate change is needed because we compile op_quantize for aten - // mode as well + Tensor& out) { + size_t num_tokens = 1; + for (size_t i = 0; i < input.dim() - 1; i++) { + num_tokens *= input.size(i); + } + // This unfortunate change is needed because we compile op_quantize for aten + // mode as well #ifdef USE_ATEN_LIB - std::vector sizes(2); - sizes[0] = num_tokens; - sizes[1] = input.size(input.dim() - 1); - Tensor reshaped_input = at::from_blob( + std::vector sizes(2); + sizes[0] = num_tokens; + sizes[1] = input.size(input.dim() - 1); + Tensor reshaped_input = at::from_blob( input.mutable_data_ptr(), sizes, at::TensorOptions(input.scalar_type())); #else - std::array input_dim_order{0, 1}; - std::array input_sizes; - input_sizes[0] = num_tokens; - input_sizes[1] = input.size(input.dim() - 1); - std::array input_strides; - executorch::runtime::dim_order_to_stride_nocheck( + std::array input_dim_order{0, 1}; + std::array input_sizes; + input_sizes[0] = num_tokens; + input_sizes[1] = input.size(input.dim() - 1); + std::array input_strides; + executorch::runtime::dim_order_to_stride_nocheck( input_sizes.data(), input_dim_order.data(), 2, input_strides.data()); - void* input_data = input.mutable_data_ptr(); - torch::executor::TensorImpl reshaped_input_impl = executorch::runtime::etensor::TensorImpl( - input.scalar_type(), - 2, - input_sizes.data(), - input_data, - input_dim_order.data(), - input_strides.data(), - executorch::runtime::TensorShapeDynamism::STATIC); - Tensor reshaped_input(&reshaped_input_impl); - torch::executor::Error err = resize_tensor(out, input.sizes()); - ET_CHECK_MSG( + void* input_data = input.mutable_data_ptr(); + torch::executor::TensorImpl reshaped_input_impl = + executorch::runtime::etensor::TensorImpl( + input.scalar_type(), + 2, + input_sizes.data(), + input_data, + input_dim_order.data(), + input_strides.data(), + executorch::runtime::TensorShapeDynamism::STATIC); + Tensor reshaped_input(&reshaped_input_impl); + torch::executor::Error err = resize_tensor(out, input.sizes()); + ET_CHECK_MSG( err == torch::executor::Error::Ok, "Failed to resize out Tensor in quantize_per_channel_out"); #endif - return quantize_per_channel_out( + return quantize_per_channel_out( reshaped_input, scale, zero_point, 0, quant_min, quant_max, dtype, out); } @@ -784,14 +795,13 @@ Tensor& quantize_per_token_out( int64_t quant_min, int64_t quant_max, ScalarType dtype, - Tensor& out) -{ - (void)context; - return quantize_per_token_out( + Tensor& out) { + (void)context; + return quantize_per_token_out( input, scale, zero_point, quant_min, quant_max, dtype, out); } } // namespace native -} // namespace G3 +} // namespace FusionG3 } // namespace impl } // namespace cadence \ No newline at end of file diff --git a/backends/cadence/fusion_g3/operators/op_softmax.cpp b/backends/cadence/fusion_g3/operators/op_softmax.cpp index 79ec6dc5d7..c3287643cc 100644 --- a/backends/cadence/fusion_g3/operators/op_softmax.cpp +++ b/backends/cadence/fusion_g3/operators/op_softmax.cpp @@ -6,12 +6,12 @@ * LICENSE file in the root directory of this source tree. */ -#include #include #include #include #include #include +#include using exec_aten::Scalar; using exec_aten::ScalarType; @@ -21,95 +21,92 @@ using torch::executor::KernelRuntimeContext; namespace cadence { namespace impl { -namespace G3 { +namespace G3 { namespace native { Tensor& softmax_out( - KernelRuntimeContext& ctx, - const Tensor& in, - int64_t dim, - bool half_to_float, - Tensor& out) -{ - (void)ctx; + KernelRuntimeContext& ctx, + const Tensor& in, + int64_t dim, + bool half_to_float, + Tensor& out) { + (void)ctx; - ET_KERNEL_CHECK( - ctx, - torch::executor::check_softmax_args(in, dim, half_to_float, out), - InvalidArgument, - out); + ET_KERNEL_CHECK( + ctx, + torch::executor::check_softmax_args(in, dim, half_to_float, out), + InvalidArgument, + out); - ET_KERNEL_CHECK( - ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - - ET_KERNEL_CHECK( - ctx, executorch::runtime::tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); - // Adjust for negative dim - dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensors_have_same_dim_order(in, out), + InvalidArgument, + out); - int inp_shapes[in.dim()]; - const exec_aten::ArrayRef in_size = in.sizes(); - for(int i = 0; i < in.dim(); i++) - { - inp_shapes[i] = in_size[i]; - } + // Adjust for negative dim + dim = dim < 0 ? dim + executorch::runtime::nonzero_dim(in) : dim; - if(out.scalar_type() == ScalarType::Float) - { - const float * const inp_data = in.const_data_ptr(); - float * const out_data = out.mutable_data_ptr(); - int axis = dim; - xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, - in.dim(), &axis); - } - else - { - ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { - const CTYPE* const in_data = in.const_data_ptr(); - CTYPE* const out_data = out.mutable_data_ptr(); + int inp_shapes[in.dim()]; + const exec_aten::ArrayRef in_size = in.sizes(); + for (int i = 0; i < in.dim(); i++) { + inp_shapes[i] = in_size[i]; + } - torch::executor::apply_over_dim( - [in_data, out_data]( - const size_t size, const size_t stride, const size_t base) { - // calculate max in softmax dim. During softmax computation each - // value is subtracted by the maximum in value before calling exp - // to preserve numerical stability. - const CTYPE max_in = torch::executor::apply_unary_reduce_fn( - [](const CTYPE val_in, CTYPE val_accum) { - return std::max(val_in, val_accum); - }, - in_data + base, - size, - stride); + if (out.scalar_type() == ScalarType::Float) { + const float* const inp_data = in.const_data_ptr(); + float* const out_data = out.mutable_data_ptr(); + int axis = dim; + xa_nn_softmax_f32_f32(out_data, inp_data, inp_shapes, in.dim(), &axis); + } else { + ET_SWITCH_FLOATH_TYPES(in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() { + const CTYPE* const in_data = in.const_data_ptr(); + CTYPE* const out_data = out.mutable_data_ptr(); - const CTYPE temp_sum = torch::executor:: - apply_unary_map_reduce_fn( - [max_in](const CTYPE val_in) { - return std::exp(val_in - max_in); - }, - [](const CTYPE mapped_in, CTYPE val_accum) { - return val_accum + mapped_in; + torch::executor::apply_over_dim( + [in_data, out_data]( + const size_t size, const size_t stride, const size_t base) { + // calculate max in softmax dim. During softmax computation each + // value is subtracted by the maximum in value before calling exp + // to preserve numerical stability. + const CTYPE max_in = torch::executor::apply_unary_reduce_fn( + [](const CTYPE val_in, CTYPE val_accum) { + return std::max(val_in, val_accum); }, in_data + base, size, stride); - torch::executor::apply_unary_map_fn( - [max_in, temp_sum](const CTYPE val_in) { - return std::exp(val_in - max_in) / temp_sum; + const CTYPE temp_sum = + torch::executor::apply_unary_map_reduce_fn( + [max_in](const CTYPE val_in) { + return std::exp(val_in - max_in); + }, + [](const CTYPE mapped_in, CTYPE val_accum) { + return val_accum + mapped_in; + }, + in_data + base, + size, + stride); + + torch::executor::apply_unary_map_fn( + [max_in, temp_sum](const CTYPE val_in) { + return std::exp(val_in - max_in) / temp_sum; }, in_data + base, out_data + base, size, stride); - }, - in, - dim); - }); - } + }, + in, + dim); + }); + } - return out; + return out; } } // namespace native