From 2c65b977d24b5d713f16c224103b880ceb39c5e3 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 21 Nov 2024 14:14:51 -0800 Subject: [PATCH] Take advantage of C++17 in scalar_type_util.h (#7022) Pull Request resolved: https://github.com/pytorch/executorch/pull/6968 I generated a big ugly table because we couldn't make promoteTypes constexpr before we had C++17. Now we have C++17. ghstack-source-id: 254805946 Differential Revision: [D66181946](https://our.internmc.facebook.com/intern/diff/D66181946/) Co-authored-by: Scott Wolchok --- .../core/exec_aten/util/genScalarTypeTable.py | 39 -- .../core/exec_aten/util/scalar_type_util.h | 387 +++++------------- 2 files changed, 108 insertions(+), 318 deletions(-) delete mode 100644 runtime/core/exec_aten/util/genScalarTypeTable.py diff --git a/runtime/core/exec_aten/util/genScalarTypeTable.py b/runtime/core/exec_aten/util/genScalarTypeTable.py deleted file mode 100644 index c2bc84c270..0000000000 --- a/runtime/core/exec_aten/util/genScalarTypeTable.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -indexToType = [ - "U1", - "I1", - "I2", - "I4", - "I8", - "F2", - "F4", - "F8", - "C2", - "C4", - "C8", - "B1", - "BF", -] -promoteTypesLookup = [ - ["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1", "BF"], - ["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1", "BF"], - ["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2", "BF"], - ["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4", "BF"], - ["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8", "BF"], - ["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2", "F4"], - ["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4", "F4"], - ["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8", "F8"], - ["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2", "C4"], - ["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4", "C4"], - ["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"], - ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1", "BF"], - ["BF", "BF", "BF", "BF", "BF", "F4", "F4", "F8", "C4", "C4", "C8", "BF", "BF"], -] -for rowIndex, row in enumerate(promoteTypesLookup): - for colIndex, col in enumerate(row): - print(f"TABLE_ENTRY({indexToType[rowIndex]}, {indexToType[colIndex]}, {col});") diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 3f186c3c64..668c3c1cac 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -69,8 +69,8 @@ struct is_floating_point : std::integral_constant< bool, std::is_floating_point::value || - std::is_same::value || - std::is_same::value> {}; + std::is_same_v || + std::is_same_v> {}; // Util to figure out if the scalar type is one of the // reduced precision floating point types. @@ -78,8 +78,8 @@ template struct is_reduced_floating_point : std::integral_constant< bool, - std::is_same::value || - std::is_same::value> {}; + std::is_same_v || + std::is_same_v> {}; template constexpr bool is_reduced_floating_point_v = @@ -662,9 +662,9 @@ struct can_cast : std::integral_constant< template < typename To, typename From, - typename std::enable_if< + std::enable_if_t< (std::is_floating_point::value && std::is_integral::value), - int>::type = 0> + int> = 0> To convert(From val) { return static_cast(static_cast(val)); } @@ -672,22 +672,28 @@ To convert(From val) { template < typename To, typename From, - typename std::enable_if< + std::enable_if_t< !(std::is_floating_point::value && std::is_integral::value), - int>::type = 0> + int> = 0> To convert(From val) { return static_cast(val); } namespace internal { - -template -struct promote_types_lookup; - -template -struct promote_types_lookup { - using type = T1; -}; +// This is generated according to NumPy's promote_types +inline constexpr auto u1 = ::executorch::aten::ScalarType::Byte; +inline constexpr auto i1 = ::executorch::aten::ScalarType::Char; +inline constexpr auto i2 = ::executorch::aten::ScalarType::Short; +inline constexpr auto i4 = ::executorch::aten::ScalarType::Int; +inline constexpr auto i8 = ::executorch::aten::ScalarType::Long; +inline constexpr auto f2 = ::executorch::aten::ScalarType::Half; +inline constexpr auto f4 = ::executorch::aten::ScalarType::Float; +inline constexpr auto f8 = ::executorch::aten::ScalarType::Double; +inline constexpr auto c2 = ::executorch::aten::ScalarType::ComplexHalf; +inline constexpr auto c4 = ::executorch::aten::ScalarType::ComplexFloat; +inline constexpr auto c8 = ::executorch::aten::ScalarType::ComplexDouble; +inline constexpr auto b1 = ::executorch::aten::ScalarType::Bool; +inline constexpr auto bf = ::executorch::aten::ScalarType::BFloat16; using U1 = typename ScalarTypeToCppType<::executorch::aten::ScalarType::Byte>::type; @@ -716,253 +722,62 @@ using B1 = using BF = typename ScalarTypeToCppType< ::executorch::aten::ScalarType::BFloat16>::type; -#define TABLE_ENTRY(key1, key2, value) \ - template <> \ - struct promote_types_lookup { \ - using type = value; \ +inline constexpr std::array<::executorch::aten::ScalarType, 13> index2dtype = { + {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}}; + +constexpr std::array< + int64_t, + static_cast(::executorch::aten::ScalarType::NumOptions)> +calculate_dtype2index() { + std::array< + int64_t, + static_cast(::executorch::aten::ScalarType::NumOptions)> + inverse = {}; + for (int64_t i = 0; + i < static_cast(::executorch::aten::ScalarType::NumOptions); + i++) { + inverse[i] = -1; + } + for (int64_t i = 0; i < static_cast(index2dtype.size()); i++) { + inverse[static_cast(index2dtype[i])] = i; } + return inverse; +} -/* promote_types_lookup is a compile-time-accessible version of the - * table in promoteTypes below; we cannot make promoteTypes constexpr - * and use it directly because we are on C++11 and thus don't have - * C++17 relaxed constexpr. The below series of entries is generated - * by genScalarTypeTable.py. */ -TABLE_ENTRY(U1, U1, U1); -TABLE_ENTRY(U1, I1, I2); -TABLE_ENTRY(U1, I2, I2); -TABLE_ENTRY(U1, I4, I4); -TABLE_ENTRY(U1, I8, I8); -TABLE_ENTRY(U1, F2, F2); -TABLE_ENTRY(U1, F4, F4); -TABLE_ENTRY(U1, F8, F8); -TABLE_ENTRY(U1, C2, C2); -TABLE_ENTRY(U1, C4, C4); -TABLE_ENTRY(U1, C8, C8); -TABLE_ENTRY(U1, B1, U1); -TABLE_ENTRY(U1, BF, BF); -TABLE_ENTRY(I1, U1, I2); -TABLE_ENTRY(I1, I1, I1); -TABLE_ENTRY(I1, I2, I2); -TABLE_ENTRY(I1, I4, I4); -TABLE_ENTRY(I1, I8, I8); -TABLE_ENTRY(I1, F2, F2); -TABLE_ENTRY(I1, F4, F4); -TABLE_ENTRY(I1, F8, F8); -TABLE_ENTRY(I1, C2, C2); -TABLE_ENTRY(I1, C4, C4); -TABLE_ENTRY(I1, C8, C8); -TABLE_ENTRY(I1, B1, I1); -TABLE_ENTRY(I1, BF, BF); -TABLE_ENTRY(I2, U1, I2); -TABLE_ENTRY(I2, I1, I2); -TABLE_ENTRY(I2, I2, I2); -TABLE_ENTRY(I2, I4, I4); -TABLE_ENTRY(I2, I8, I8); -TABLE_ENTRY(I2, F2, F2); -TABLE_ENTRY(I2, F4, F4); -TABLE_ENTRY(I2, F8, F8); -TABLE_ENTRY(I2, C2, C2); -TABLE_ENTRY(I2, C4, C4); -TABLE_ENTRY(I2, C8, C8); -TABLE_ENTRY(I2, B1, I2); -TABLE_ENTRY(I2, BF, BF); -TABLE_ENTRY(I4, U1, I4); -TABLE_ENTRY(I4, I1, I4); -TABLE_ENTRY(I4, I2, I4); -TABLE_ENTRY(I4, I4, I4); -TABLE_ENTRY(I4, I8, I8); -TABLE_ENTRY(I4, F2, F2); -TABLE_ENTRY(I4, F4, F4); -TABLE_ENTRY(I4, F8, F8); -TABLE_ENTRY(I4, C2, C2); -TABLE_ENTRY(I4, C4, C4); -TABLE_ENTRY(I4, C8, C8); -TABLE_ENTRY(I4, B1, I4); -TABLE_ENTRY(I4, BF, BF); -TABLE_ENTRY(I8, U1, I8); -TABLE_ENTRY(I8, I1, I8); -TABLE_ENTRY(I8, I2, I8); -TABLE_ENTRY(I8, I4, I8); -TABLE_ENTRY(I8, I8, I8); -TABLE_ENTRY(I8, F2, F2); -TABLE_ENTRY(I8, F4, F4); -TABLE_ENTRY(I8, F8, F8); -TABLE_ENTRY(I8, C2, C2); -TABLE_ENTRY(I8, C4, C4); -TABLE_ENTRY(I8, C8, C8); -TABLE_ENTRY(I8, B1, I8); -TABLE_ENTRY(I8, BF, BF); -TABLE_ENTRY(F2, U1, F2); -TABLE_ENTRY(F2, I1, F2); -TABLE_ENTRY(F2, I2, F2); -TABLE_ENTRY(F2, I4, F2); -TABLE_ENTRY(F2, I8, F2); -TABLE_ENTRY(F2, F2, F2); -TABLE_ENTRY(F2, F4, F4); -TABLE_ENTRY(F2, F8, F8); -TABLE_ENTRY(F2, C2, C2); -TABLE_ENTRY(F2, C4, C4); -TABLE_ENTRY(F2, C8, C8); -TABLE_ENTRY(F2, B1, F2); -TABLE_ENTRY(F2, BF, F4); -TABLE_ENTRY(F4, U1, F4); -TABLE_ENTRY(F4, I1, F4); -TABLE_ENTRY(F4, I2, F4); -TABLE_ENTRY(F4, I4, F4); -TABLE_ENTRY(F4, I8, F4); -TABLE_ENTRY(F4, F2, F4); -TABLE_ENTRY(F4, F4, F4); -TABLE_ENTRY(F4, F8, F8); -TABLE_ENTRY(F4, C2, C4); -TABLE_ENTRY(F4, C4, C4); -TABLE_ENTRY(F4, C8, C8); -TABLE_ENTRY(F4, B1, F4); -TABLE_ENTRY(F4, BF, F4); -TABLE_ENTRY(F8, U1, F8); -TABLE_ENTRY(F8, I1, F8); -TABLE_ENTRY(F8, I2, F8); -TABLE_ENTRY(F8, I4, F8); -TABLE_ENTRY(F8, I8, F8); -TABLE_ENTRY(F8, F2, F8); -TABLE_ENTRY(F8, F4, F8); -TABLE_ENTRY(F8, F8, F8); -TABLE_ENTRY(F8, C2, C8); -TABLE_ENTRY(F8, C4, C8); -TABLE_ENTRY(F8, C8, C8); -TABLE_ENTRY(F8, B1, F8); -TABLE_ENTRY(F8, BF, F8); -TABLE_ENTRY(C2, U1, C2); -TABLE_ENTRY(C2, I1, C2); -TABLE_ENTRY(C2, I2, C2); -TABLE_ENTRY(C2, I4, C2); -TABLE_ENTRY(C2, I8, C2); -TABLE_ENTRY(C2, F2, C2); -TABLE_ENTRY(C2, F4, C4); -TABLE_ENTRY(C2, F8, C8); -TABLE_ENTRY(C2, C2, C2); -TABLE_ENTRY(C2, C4, C4); -TABLE_ENTRY(C2, C8, C8); -TABLE_ENTRY(C2, B1, C2); -TABLE_ENTRY(C2, BF, C4); -TABLE_ENTRY(C4, U1, C4); -TABLE_ENTRY(C4, I1, C4); -TABLE_ENTRY(C4, I2, C4); -TABLE_ENTRY(C4, I4, C4); -TABLE_ENTRY(C4, I8, C4); -TABLE_ENTRY(C4, F2, C4); -TABLE_ENTRY(C4, F4, C4); -TABLE_ENTRY(C4, F8, C8); -TABLE_ENTRY(C4, C2, C4); -TABLE_ENTRY(C4, C4, C4); -TABLE_ENTRY(C4, C8, C8); -TABLE_ENTRY(C4, B1, C4); -TABLE_ENTRY(C4, BF, C4); -TABLE_ENTRY(C8, U1, C8); -TABLE_ENTRY(C8, I1, C8); -TABLE_ENTRY(C8, I2, C8); -TABLE_ENTRY(C8, I4, C8); -TABLE_ENTRY(C8, I8, C8); -TABLE_ENTRY(C8, F2, C8); -TABLE_ENTRY(C8, F4, C8); -TABLE_ENTRY(C8, F8, C8); -TABLE_ENTRY(C8, C2, C8); -TABLE_ENTRY(C8, C4, C8); -TABLE_ENTRY(C8, C8, C8); -TABLE_ENTRY(C8, B1, C8); -TABLE_ENTRY(C8, BF, C8); -TABLE_ENTRY(B1, U1, U1); -TABLE_ENTRY(B1, I1, I1); -TABLE_ENTRY(B1, I2, I2); -TABLE_ENTRY(B1, I4, I4); -TABLE_ENTRY(B1, I8, I8); -TABLE_ENTRY(B1, F2, F2); -TABLE_ENTRY(B1, F4, F4); -TABLE_ENTRY(B1, F8, F8); -TABLE_ENTRY(B1, C2, C2); -TABLE_ENTRY(B1, C4, C4); -TABLE_ENTRY(B1, C8, C8); -TABLE_ENTRY(B1, B1, B1); -TABLE_ENTRY(B1, BF, BF); -TABLE_ENTRY(BF, U1, BF); -TABLE_ENTRY(BF, I1, BF); -TABLE_ENTRY(BF, I2, BF); -TABLE_ENTRY(BF, I4, BF); -TABLE_ENTRY(BF, I8, BF); -TABLE_ENTRY(BF, F2, F4); -TABLE_ENTRY(BF, F4, F4); -TABLE_ENTRY(BF, F8, F8); -TABLE_ENTRY(BF, C2, C4); -TABLE_ENTRY(BF, C4, C4); -TABLE_ENTRY(BF, C8, C8); -TABLE_ENTRY(BF, B1, BF); -TABLE_ENTRY(BF, BF, BF); +inline constexpr auto dtype2index = calculate_dtype2index(); +inline constexpr int NUM_PROMOTE_TYPES = 13; +// Should match _promoteTypesLookup in c10/core/ScalarType.cpp so that +// we match PyTorch core type promotion semantics. +inline constexpr ::executorch::aten::ScalarType + promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = { + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, +}; } // namespace internal -template -struct promote_types { - private: - static_assert( - std::is_same::value || - (!is_qint_type::value && !is_qint_type::value), - "promote_types not valid for quantized dtypes"); - static_assert( - std::is_same::value || - (!is_bits_type::value && !is_bits_type::value), - "promote_types not valid for bits dtypes"); - static_assert( - std::is_same::value || - (!is_float8_type::value && !is_float8_type::value), - "promote_types not valid for float8 dtypes"); - static_assert( - std::is_same::value || - (!is_barebones_unsigned_type::value && - !is_barebones_unsigned_type::value), - "promote_types not valid for barebones unsigned dtypes"); - - using promoted_type_not_respecting_half_to_float = - typename internal::promote_types_lookup::type; - - public: - using type = typename std::conditional< - half_to_float && - (std::is_same< - promoted_type_not_respecting_half_to_float, - typename ScalarTypeToCppType< - ::executorch::aten::ScalarType::Half>::type>::value || - std::is_same< - promoted_type_not_respecting_half_to_float, - typename ScalarTypeToCppType< - ::executorch::aten::ScalarType::BFloat16>::type>::value), - typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type, - promoted_type_not_respecting_half_to_float>::type; -}; - /** * Implements type promotion rules that are consistent with ATen behaviour, * which in turn is consistent with NumPy's promote_types. * If half_to_float is set to true, then half and bfloat16 will be promoted to * float instead */ -inline ::executorch::aten::ScalarType promoteTypes( +inline constexpr ::executorch::aten::ScalarType promoteTypes( ::executorch::aten::ScalarType a, ::executorch::aten::ScalarType b, bool half_to_float = false) { - // This is generated according to NumPy's promote_types - constexpr auto u1 = ::executorch::aten::ScalarType::Byte; - constexpr auto i1 = ::executorch::aten::ScalarType::Char; - constexpr auto i2 = ::executorch::aten::ScalarType::Short; - constexpr auto i4 = ::executorch::aten::ScalarType::Int; - constexpr auto i8 = ::executorch::aten::ScalarType::Long; - constexpr auto f2 = ::executorch::aten::ScalarType::Half; - constexpr auto f4 = ::executorch::aten::ScalarType::Float; - constexpr auto f8 = ::executorch::aten::ScalarType::Double; - constexpr auto c2 = ::executorch::aten::ScalarType::ComplexHalf; - constexpr auto c4 = ::executorch::aten::ScalarType::ComplexFloat; - constexpr auto c8 = ::executorch::aten::ScalarType::ComplexDouble; - constexpr auto b1 = ::executorch::aten::ScalarType::Bool; - constexpr auto bf = ::executorch::aten::ScalarType::BFloat16; - // For QInt types, only allow exact match if (::executorch::runtime::isQIntType(a) && a == b) { return a; @@ -999,39 +814,12 @@ inline ::executorch::aten::ScalarType promoteTypes( ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes"); } - // 12 types are handled by this function, see the constexpr definitions above - const int NUM_PROMOTE_TYPES = 13; - - static constexpr std:: - array - dtype2index = {{ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, - }}; - auto ix_a = dtype2index[(int)a]; + auto ix_a = ::executorch::runtime::internal::dtype2index[(int)a]; ET_CHECK(ix_a != -1); - auto ix_b = dtype2index[(int)b]; + auto ix_b = ::executorch::runtime::internal::dtype2index[(int)b]; ET_CHECK(ix_b != -1); - static constexpr ::executorch::aten::ScalarType - _promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, - /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, - /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, - }; - ::executorch::aten::ScalarType promoted_type = - _promoteTypesLookup[ix_a][ix_b]; + ::executorch::runtime::internal::promoteTypesLookup[ix_a][ix_b]; if (half_to_float && (promoted_type == ::executorch::aten::ScalarType::Half || @@ -1042,6 +830,47 @@ inline ::executorch::aten::ScalarType promoteTypes( return promoted_type; } +template +struct promote_types { + private: + static_assert( + std::is_same_v || + (!is_qint_type::value && !is_qint_type::value), + "promote_types not valid for quantized dtypes"); + static_assert( + std::is_same_v || + (!is_bits_type::value && !is_bits_type::value), + "promote_types not valid for bits dtypes"); + static_assert( + std::is_same_v || + (!is_float8_type::value && !is_float8_type::value), + "promote_types not valid for float8 dtypes"); + static_assert( + std::is_same_v || + (!is_barebones_unsigned_type::value && + !is_barebones_unsigned_type::value), + "promote_types not valid for barebones unsigned dtypes"); + + using promoted_type_not_respecting_half_to_float = + typename ScalarTypeToCppType::value, + CppTypeToScalarType::value)>::type; + + public: + using type = std::conditional_t< + half_to_float && + (std::is_same_v< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::Half>::type> || + std::is_same_v< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + ::executorch::aten::ScalarType::BFloat16>::type>), + typename ScalarTypeToCppType<::executorch::aten::ScalarType::Float>::type, + promoted_type_not_respecting_half_to_float>; +}; + // // Helper macros for switch case macros (see below) //