From 2db109a1517b21c1f85512ca78c4cac26b3bbdb5 Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Mon, 22 Apr 2024 08:48:20 -0700 Subject: [PATCH] [TSL] Move away from deprecated typedef PiperOrigin-RevId: 627050864 --- third_party/tsl/tsl/framework/type_traits.h | 2 +- third_party/tsl/tsl/platform/ml_dtypes.h | 2 -- .../evaluator/hlo_evaluator_typed_visitor.h | 2 +- .../hlo_evaluator_typed_visitor_float8.cc | 2 +- xla/literal.cc | 6 ++--- xla/literal_test.cc | 20 ++++++++-------- xla/primitive_util.h | 4 ++-- xla/python/py_values.cc | 4 ++-- xla/service/elemental_ir_emitter.cc | 2 +- xla/tests/constants_test.cc | 5 ++-- xla/tests/convert_test.cc | 23 +++++++++++-------- xla/util.cc | 2 +- xla/util.h | 2 +- xla/util_test.cc | 12 ++++++---- 14 files changed, 46 insertions(+), 42 deletions(-) diff --git a/third_party/tsl/tsl/framework/type_traits.h b/third_party/tsl/tsl/framework/type_traits.h index e96334d0027f2..4b8eed47bde76 100644 --- a/third_party/tsl/tsl/framework/type_traits.h +++ b/third_party/tsl/tsl/framework/type_traits.h @@ -71,7 +71,7 @@ struct is_simple_type { std::is_same::value || std::is_same::value || is_quantized::value || std::is_same::value || std::is_same::value || - std::is_same::value || + std::is_same::value || std::is_same::value || std::is_same::value || std::is_same::value; }; diff --git a/third_party/tsl/tsl/platform/ml_dtypes.h b/third_party/tsl/tsl/platform/ml_dtypes.h index 00ba80538b502..916be8db4f699 100644 --- a/third_party/tsl/tsl/platform/ml_dtypes.h +++ b/third_party/tsl/tsl/platform/ml_dtypes.h @@ -23,8 +23,6 @@ namespace tsl { using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn; using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz; using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz; -using float8_e4m3b11 = float8_e4m3b11fnuz; // Deprecated: old name for - // backward-compatibility only. using float8_e5m2 = ::ml_dtypes::float8_e5m2; using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h index 68b79d25bf5d2..1ebf786af7bcb 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor.h @@ -1704,7 +1704,7 @@ extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; -extern template class HloEvaluatorTypedVisitor; +extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; extern template class HloEvaluatorTypedVisitor; diff --git a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc index b2cd8eb87292e..9df467e7fd5f6 100644 --- a/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc +++ b/xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc @@ -19,7 +19,7 @@ limitations under the License. namespace xla { template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; -template class HloEvaluatorTypedVisitor; +template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; template class HloEvaluatorTypedVisitor; } // namespace xla diff --git a/xla/literal.cc b/xla/literal.cc index db29b1d94b92a..2fa8b8f2a5663 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -2259,7 +2259,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { break; case F8E4M3B11FNUZ: *proto->mutable_f8e4m3b11fnuzs() = std::string( - reinterpret_cast(data().data()), + reinterpret_cast(data().data()), size_bytes_dense()); break; case F8E5M2FNUZ: @@ -2440,8 +2440,8 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) { } case F8E4M3B11FNUZ: { const std::string& s(proto.f8e4m3b11fnuzs()); - TF_RET_CHECK(data().size() * - sizeof(tsl::float8_e4m3b11) == + TF_RET_CHECK(data().size() * + sizeof(tsl::float8_e4m3b11fnuz) == s.size()); memcpy(untyped_data(), s.data(), s.size()); break; diff --git a/xla/literal_test.cc b/xla/literal_test.cc index e216c3a4b7eac..aa628addce25e 100644 --- a/xla/literal_test.cc +++ b/xla/literal_test.cc @@ -175,8 +175,8 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { LiteralUtil::CreateR0(tsl::float8_e4m3fn(0.5)); EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString()); - auto f8e4m3b11fnuz_lit = - LiteralUtil::CreateR0(tsl::float8_e4m3b11(0.5)); + auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0( + tsl::float8_e4m3b11fnuz(0.5)); EXPECT_EQ("f8e4m3b11fnuz[] 0.5", f8e4m3b11fnuz_lit.ToString()); auto f8e4m3fnuz_lit = @@ -628,9 +628,9 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(LiteralUtil::CreateR1({r16}).IsAll(8)); EXPECT_TRUE(LiteralUtil::CreateR1({r16}).IsAll(9)); - tsl::float8_e4m3b11 s16(9); // Exactly representable in e4m3 - EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); - EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); + tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3 + EXPECT_FALSE(LiteralUtil::CreateR1({s16}).IsAll(8)); + EXPECT_TRUE(LiteralUtil::CreateR1({s16}).IsAll(9)); tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3 EXPECT_FALSE(LiteralUtil::CreateR1({t16}).IsAll(8)); @@ -1195,9 +1195,9 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) { TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3b11) { Literal output(ShapeUtil::MakeShape(F8E4M3B11FNUZ, {3})); - tsl::float8_e4m3b11 x(0.5f); - output.PopulateWithValue(x); - auto expected = LiteralUtil::CreateR1({x, x, x}); + tsl::float8_e4m3b11fnuz x(0.5f); + output.PopulateWithValue(x); + auto expected = LiteralUtil::CreateR1({x, x, x}); EXPECT_EQ(output, expected); } @@ -1710,7 +1710,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) { using e4 = tsl::float8_e4m3fn; auto f8e4m3 = LiteralUtil::CreateR2WithLayout( {{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_); - using b11 = tsl::float8_e4m3b11; + using b11 = tsl::float8_e4m3b11fnuz; auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout( {{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_); using e5f = tsl::float8_e5m2fnuz; @@ -2198,7 +2198,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) { using e4 = tsl::float8_e4m3fn; auto vector_f8e4m3 = LiteralUtil::CreateR1({e4{10.0}, e4{20.0}, e4{-32.0}}); - using b11 = tsl::float8_e4m3b11; + using b11 = tsl::float8_e4m3b11fnuz; auto vector_f8e4m3b11 = LiteralUtil::CreateR1({b11{10.0}, b11{20.0}, b11{-30.0}}); using e5f = tsl::float8_e5m2fnuz; diff --git a/xla/primitive_util.h b/xla/primitive_util.h index dec252498fc62..31989df8fe084 100644 --- a/xla/primitive_util.h +++ b/xla/primitive_util.h @@ -185,7 +185,7 @@ constexpr PrimitiveType NativeToPrimitiveType() { } template <> -constexpr PrimitiveType NativeToPrimitiveType() { +constexpr PrimitiveType NativeToPrimitiveType() { return F8E4M3B11FNUZ; } @@ -315,7 +315,7 @@ struct PrimitiveTypeToNative { template <> struct PrimitiveTypeToNative { - using type = tsl::float8_e4m3b11; + using type = tsl::float8_e4m3b11fnuz; }; template <> diff --git a/xla/python/py_values.cc b/xla/python/py_values.cc index 0a2df796f4fd6..e487bed3d4986 100644 --- a/xla/python/py_values.cc +++ b/xla/python/py_values.cc @@ -181,7 +181,7 @@ absl::StatusOr HandleNumpyScalar( } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3FN; - } else if (std::is_same()) { + } else if (std::is_same()) { PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>()); type = F8E4M3B11FNUZ; } else if (std::is_same()) { @@ -382,7 +382,7 @@ absl::StatusOr DevicePut(nb::handle arg, (*p)[dtypes.np_float8_e4m3fn.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] = - HandleNumpyScalar; + HandleNumpyScalar; (*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar; (*p)[dtypes.np_float8_e4m3fnuz.ptr()] = HandleNumpyScalar; diff --git a/xla/service/elemental_ir_emitter.cc b/xla/service/elemental_ir_emitter.cc index 2a40e473ab60c..ed4c16fe37cbf 100644 --- a/xla/service/elemental_ir_emitter.cc +++ b/xla/service/elemental_ir_emitter.cc @@ -511,7 +511,7 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value, llvm::Value* max_like_sign = llvm::ConstantFP::get( f16_value->getType(), - static_cast(std::numeric_limits::max())); + static_cast(std::numeric_limits::max())); max_like_sign = b->CreateBitCast(max_like_sign, f16_sign_bit->getType()); max_like_sign = b->CreateOr(max_like_sign, f16_sign_bit); max_like_sign = b->CreateBitCast(max_like_sign, f16_value->getType()); diff --git a/xla/tests/constants_test.cc b/xla/tests/constants_test.cc index 54037278e41dd..a926d24819fd6 100644 --- a/xla/tests/constants_test.cc +++ b/xla/tests/constants_test.cc @@ -122,10 +122,11 @@ TEST_F(ConstantsTest, OneCellF8e5m2) { } TEST_F(ConstantsTest, OneCellF8e4m3b11fnuz) { - std::vector constant = {tsl::float8_e4m3b11{2.0}}; + std::vector constant = { + tsl::float8_e4m3b11fnuz{2.0}}; XlaBuilder builder(TestName()); - auto c = ConstantR1(&builder, constant); + auto c = ConstantR1(&builder, constant); // F8 outputs are not yet supported so convert to F32 ConvertElementType(c, F32); diff --git a/xla/tests/convert_test.cc b/xla/tests/convert_test.cc index 91d545023c637..13ca51a4025eb 100644 --- a/xla/tests/convert_test.cc +++ b/xla/tests/convert_test.cc @@ -960,13 +960,14 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive) { // Convert from FP8 to FP16, then back to FP8 XlaBuilder builder(TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast( + static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16); ConvertElementType(all_f8_as_f16, F8E4M3B11FNUZ); ComputeAndCompare(&builder, {}, ErrorSpec(0.)); @@ -978,8 +979,9 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive2) { std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back(static_cast( - Eigen::numext::bit_cast(static_cast(i)))); + all_f8.push_back( + static_cast(Eigen::numext::bit_cast( + static_cast(i)))); } xla::XlaOp all_f8_as_f32 = ConstantR1(&builder, all_f8); @@ -991,13 +993,14 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive3) { // Convert from FP8 to FP32. XlaBuilder builder(TestName()); - std::vector all_f8; + std::vector all_f8; for (int i = 0; i < 256; i++) { - all_f8.push_back( - Eigen::numext::bit_cast(static_cast(i))); + all_f8.push_back(Eigen::numext::bit_cast( + static_cast(i))); } - xla::XlaOp all_f8_as_f8 = ConstantR1(&builder, all_f8); + xla::XlaOp all_f8_as_f8 = + ConstantR1(&builder, all_f8); ConvertElementType(all_f8_as_f8, F32); ComputeAndCompare(&builder, {}, ErrorSpec(0.)); } diff --git a/xla/util.cc b/xla/util.cc index 342c6ec65d457..52fb915e8dfee 100644 --- a/xla/util.cc +++ b/xla/util.cc @@ -182,7 +182,7 @@ std::string RoundTripFpToString(tsl::float8_e4m3fn value) { return result; } -std::string RoundTripFpToString(tsl::float8_e4m3b11 value) { +std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value) { std::string result = GenericRoundTripFpToString(value); return result; } diff --git a/xla/util.h b/xla/util.h index 758eb83dc8d9e..7665365408565 100644 --- a/xla/util.h +++ b/xla/util.h @@ -424,7 +424,7 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value); std::string RoundTripFpToString(tsl::float8_e4m3fn value); // Returns a string which can losslessly round trip to a float8 E4M3B11. -std::string RoundTripFpToString(tsl::float8_e4m3b11 value); +std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value); // Returns a string which can losslessly round trip to a float8 E5M2FNUZ. std::string RoundTripFpToString(tsl::float8_e5m2fnuz value); diff --git a/xla/util_test.cc b/xla/util_test.cc index 205e07febd5ed..052d8cd2fe7ea 100644 --- a/xla/util_test.cc +++ b/xla/util_test.cc @@ -135,7 +135,7 @@ TEST(UtilTest, RoundTripFpToString) { -std::numeric_limits::quiet_NaN()), "-nan"); EXPECT_EQ(RoundTripFpToString( - std::numeric_limits::quiet_NaN()), + std::numeric_limits::quiet_NaN()), "-nan"); EXPECT_EQ(RoundTripFpToString( std::numeric_limits::quiet_NaN()), @@ -249,11 +249,13 @@ TEST(UtilTest, TotalOrder_F8E4M3FN) { TEST(UtilTest, TotalOrder_F8E4M3B11) { for (int a = 0; a < 256; ++a) { - tsl::float8_e4m3b11 x = - Eigen::numext::bit_cast(static_cast(a)); + tsl::float8_e4m3b11fnuz x = + Eigen::numext::bit_cast( + static_cast(a)); for (int b = 0; b < 256; ++b) { - tsl::float8_e4m3b11 y = - Eigen::numext::bit_cast(static_cast(b)); + tsl::float8_e4m3b11fnuz y = + Eigen::numext::bit_cast( + static_cast(b)); TotalOrderHelper(x, y); } }