Skip to content

Commit

Permalink
[TSL] Move away from deprecated typedef
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627050864
  • Loading branch information
majnemer authored and copybara-github committed Apr 22, 2024
1 parent c00dde1 commit 2db109a
Show file tree
Hide file tree
Showing 14 changed files with 46 additions and 42 deletions.
2 changes: 1 addition & 1 deletion third_party/tsl/tsl/framework/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ struct is_simple_type {
std::is_same<T, complex64>::value || std::is_same<T, complex128>::value ||
is_quantized<T>::value || std::is_same<T, bfloat16>::value ||
std::is_same<T, float8_e4m3fn>::value ||
std::is_same<T, float8_e4m3b11>::value ||
std::is_same<T, float8_e4m3b11fnuz>::value ||
std::is_same<T, float8_e5m2>::value || std::is_same<T, int4>::value ||
std::is_same<T, uint4>::value;
};
Expand Down
2 changes: 0 additions & 2 deletions third_party/tsl/tsl/platform/ml_dtypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ namespace tsl {
using float8_e4m3fn = ::ml_dtypes::float8_e4m3fn;
using float8_e4m3fnuz = ::ml_dtypes::float8_e4m3fnuz;
using float8_e4m3b11fnuz = ::ml_dtypes::float8_e4m3b11fnuz;
using float8_e4m3b11 = float8_e4m3b11fnuz; // Deprecated: old name for
// backward-compatibility only.
using float8_e5m2 = ::ml_dtypes::float8_e5m2;
using float8_e5m2fnuz = ::ml_dtypes::float8_e5m2fnuz;

Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/evaluator/hlo_evaluator_typed_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,7 @@ extern template class HloEvaluatorTypedVisitor<complex128>;
extern template class HloEvaluatorTypedVisitor<bfloat16, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
extern template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;

Expand Down
2 changes: 1 addition & 1 deletion xla/hlo/evaluator/hlo_evaluator_typed_visitor_float8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License.
namespace xla {
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fn, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3b11fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e5m2fnuz, float>;
template class HloEvaluatorTypedVisitor<tsl::float8_e4m3fnuz, float>;
} // namespace xla
6 changes: 3 additions & 3 deletions xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2259,7 +2259,7 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
break;
case F8E4M3B11FNUZ:
*proto->mutable_f8e4m3b11fnuzs() = std::string(
reinterpret_cast<const char*>(data<tsl::float8_e4m3b11>().data()),
reinterpret_cast<const char*>(data<tsl::float8_e4m3b11fnuz>().data()),
size_bytes_dense());
break;
case F8E5M2FNUZ:
Expand Down Expand Up @@ -2440,8 +2440,8 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
}
case F8E4M3B11FNUZ: {
const std::string& s(proto.f8e4m3b11fnuzs());
TF_RET_CHECK(data<tsl::float8_e4m3b11>().size() *
sizeof(tsl::float8_e4m3b11) ==
TF_RET_CHECK(data<tsl::float8_e4m3b11fnuz>().size() *
sizeof(tsl::float8_e4m3b11fnuz) ==
s.size());
memcpy(untyped_data(), s.data(), s.size());
break;
Expand Down
20 changes: 10 additions & 10 deletions xla/literal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
LiteralUtil::CreateR0<tsl::float8_e4m3fn>(tsl::float8_e4m3fn(0.5));
EXPECT_EQ("f8e4m3fn[] 0.5", f8e4m3_lit.ToString());

auto f8e4m3b11fnuz_lit =
LiteralUtil::CreateR0<tsl::float8_e4m3b11>(tsl::float8_e4m3b11(0.5));
auto f8e4m3b11fnuz_lit = LiteralUtil::CreateR0<tsl::float8_e4m3b11fnuz>(
tsl::float8_e4m3b11fnuz(0.5));
EXPECT_EQ("f8e4m3b11fnuz[] 0.5", f8e4m3b11fnuz_lit.ToString());

auto f8e4m3fnuz_lit =
Expand Down Expand Up @@ -628,9 +628,9 @@ TEST_F(LiteralUtilTest, IsAll) {
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e4m3fn>({r16}).IsAll(8));
EXPECT_TRUE(LiteralUtil::CreateR1<tsl::float8_e4m3fn>({r16}).IsAll(9));

tsl::float8_e4m3b11 s16(9); // Exactly representable in e4m3
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e4m3b11>({s16}).IsAll(8));
EXPECT_TRUE(LiteralUtil::CreateR1<tsl::float8_e4m3b11>({s16}).IsAll(9));
tsl::float8_e4m3b11fnuz s16(9); // Exactly representable in e4m3
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e4m3b11fnuz>({s16}).IsAll(8));
EXPECT_TRUE(LiteralUtil::CreateR1<tsl::float8_e4m3b11fnuz>({s16}).IsAll(9));

tsl::float8_e4m3fnuz t16(9); // Exactly representable in e4m3
EXPECT_FALSE(LiteralUtil::CreateR1<tsl::float8_e4m3fnuz>({t16}).IsAll(8));
Expand Down Expand Up @@ -1195,9 +1195,9 @@ TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3) {

TEST_F(LiteralUtilTest, PopulateWithValueR1F8e4m3b11) {
Literal output(ShapeUtil::MakeShape(F8E4M3B11FNUZ, {3}));
tsl::float8_e4m3b11 x(0.5f);
output.PopulateWithValue<tsl::float8_e4m3b11>(x);
auto expected = LiteralUtil::CreateR1<tsl::float8_e4m3b11>({x, x, x});
tsl::float8_e4m3b11fnuz x(0.5f);
output.PopulateWithValue<tsl::float8_e4m3b11fnuz>(x);
auto expected = LiteralUtil::CreateR1<tsl::float8_e4m3b11fnuz>({x, x, x});
EXPECT_EQ(output, expected);
}

Expand Down Expand Up @@ -1710,7 +1710,7 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatchF8) {
using e4 = tsl::float8_e4m3fn;
auto f8e4m3 = LiteralUtil::CreateR2WithLayout<e4>(
{{e4{0.}, e4{1.}}, {e4{2.}, e4{3.}}}, layout_r2_dim0major_);
using b11 = tsl::float8_e4m3b11;
using b11 = tsl::float8_e4m3b11fnuz;
auto f8e4m3b11 = LiteralUtil::CreateR2WithLayout<b11>(
{{b11{0.}, b11{1.}}, {b11{2.}, b11{3.}}}, layout_r2_dim0major_);
using e5f = tsl::float8_e5m2fnuz;
Expand Down Expand Up @@ -2198,7 +2198,7 @@ TEST_F(LiteralUtilTest, ProtoRoundTrip) {
using e4 = tsl::float8_e4m3fn;
auto vector_f8e4m3 =
LiteralUtil::CreateR1<e4>({e4{10.0}, e4{20.0}, e4{-32.0}});
using b11 = tsl::float8_e4m3b11;
using b11 = tsl::float8_e4m3b11fnuz;
auto vector_f8e4m3b11 =
LiteralUtil::CreateR1<b11>({b11{10.0}, b11{20.0}, b11{-30.0}});
using e5f = tsl::float8_e5m2fnuz;
Expand Down
4 changes: 2 additions & 2 deletions xla/primitive_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3fn>() {
}

template <>
constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3b11>() {
constexpr PrimitiveType NativeToPrimitiveType<tsl::float8_e4m3b11fnuz>() {
return F8E4M3B11FNUZ;
}

Expand Down Expand Up @@ -315,7 +315,7 @@ struct PrimitiveTypeToNative<F8E4M3FN> {

template <>
struct PrimitiveTypeToNative<F8E4M3B11FNUZ> {
using type = tsl::float8_e4m3b11;
using type = tsl::float8_e4m3b11fnuz;
};

template <>
Expand Down
4 changes: 2 additions & 2 deletions xla/python/py_values.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ absl::StatusOr<DevicePutResultFn> HandleNumpyScalar(
} else if (std::is_same<T, tsl::float8_e4m3fn>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3FN;
} else if (std::is_same<T, tsl::float8_e4m3b11>()) {
} else if (std::is_same<T, tsl::float8_e4m3b11fnuz>()) {
PyArray_ScalarAsCtype(h.ptr(), &data.template emplace<2>());
type = F8E4M3B11FNUZ;
} else if (std::is_same<T, tsl::float8_e5m2>()) {
Expand Down Expand Up @@ -382,7 +382,7 @@ absl::StatusOr<DevicePutResultFn> DevicePut(nb::handle arg,
(*p)[dtypes.np_float8_e4m3fn.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3fn>;
(*p)[dtypes.np_float8_e4m3b11fnuz.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3b11>;
HandleNumpyScalar<tsl::float8_e4m3b11fnuz>;
(*p)[dtypes.np_float8_e5m2.ptr()] = HandleNumpyScalar<tsl::float8_e5m2>;
(*p)[dtypes.np_float8_e4m3fnuz.ptr()] =
HandleNumpyScalar<tsl::float8_e4m3fnuz>;
Expand Down
2 changes: 1 addition & 1 deletion xla/service/elemental_ir_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ llvm::Value* EmitF8e4m3b11fnuzToF16(llvm::Value* f8_value,

llvm::Value* max_like_sign = llvm::ConstantFP::get(
f16_value->getType(),
static_cast<float>(std::numeric_limits<tsl::float8_e4m3b11>::max()));
static_cast<float>(std::numeric_limits<tsl::float8_e4m3b11fnuz>::max()));
max_like_sign = b->CreateBitCast(max_like_sign, f16_sign_bit->getType());
max_like_sign = b->CreateOr(max_like_sign, f16_sign_bit);
max_like_sign = b->CreateBitCast(max_like_sign, f16_value->getType());
Expand Down
5 changes: 3 additions & 2 deletions xla/tests/constants_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,11 @@ TEST_F(ConstantsTest, OneCellF8e5m2) {
}

TEST_F(ConstantsTest, OneCellF8e4m3b11fnuz) {
std::vector<tsl::float8_e4m3b11> constant = {tsl::float8_e4m3b11{2.0}};
std::vector<tsl::float8_e4m3b11fnuz> constant = {
tsl::float8_e4m3b11fnuz{2.0}};

XlaBuilder builder(TestName());
auto c = ConstantR1<tsl::float8_e4m3b11>(&builder, constant);
auto c = ConstantR1<tsl::float8_e4m3b11fnuz>(&builder, constant);
// F8 outputs are not yet supported so convert to F32
ConvertElementType(c, F32);

Expand Down
23 changes: 13 additions & 10 deletions xla/tests/convert_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,13 +960,14 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive) {
// Convert from FP8 to FP16, then back to FP8
XlaBuilder builder(TestName());

std::vector<tsl::float8_e4m3b11> all_f8;
std::vector<tsl::float8_e4m3b11fnuz> all_f8;
for (int i = 0; i < 256; i++) {
all_f8.push_back(
Eigen::numext::bit_cast<tsl::float8_e4m3b11>(static_cast<uint8_t>(i)));
all_f8.push_back(Eigen::numext::bit_cast<tsl::float8_e4m3b11fnuz>(
static_cast<uint8_t>(i)));
}

xla::XlaOp all_f8_as_f8 = ConstantR1<tsl::float8_e4m3b11>(&builder, all_f8);
xla::XlaOp all_f8_as_f8 =
ConstantR1<tsl::float8_e4m3b11fnuz>(&builder, all_f8);
xla::XlaOp all_f8_as_f16 = ConvertElementType(all_f8_as_f8, F16);
ConvertElementType(all_f8_as_f16, F8E4M3B11FNUZ);
ComputeAndCompare(&builder, {}, ErrorSpec(0.));
Expand All @@ -978,8 +979,9 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive2) {

std::vector<float> all_f8;
for (int i = 0; i < 256; i++) {
all_f8.push_back(static_cast<float>(
Eigen::numext::bit_cast<tsl::float8_e4m3b11>(static_cast<uint8_t>(i))));
all_f8.push_back(
static_cast<float>(Eigen::numext::bit_cast<tsl::float8_e4m3b11fnuz>(
static_cast<uint8_t>(i))));
}

xla::XlaOp all_f8_as_f32 = ConstantR1<float>(&builder, all_f8);
Expand All @@ -991,13 +993,14 @@ XLA_TEST_F(ConvertTest, ConvertF8e4m3b11fnuzF16RoundtripExhaustive3) {
// Convert from FP8 to FP32.
XlaBuilder builder(TestName());

std::vector<tsl::float8_e4m3b11> all_f8;
std::vector<tsl::float8_e4m3b11fnuz> all_f8;
for (int i = 0; i < 256; i++) {
all_f8.push_back(
Eigen::numext::bit_cast<tsl::float8_e4m3b11>(static_cast<uint8_t>(i)));
all_f8.push_back(Eigen::numext::bit_cast<tsl::float8_e4m3b11fnuz>(
static_cast<uint8_t>(i)));
}

xla::XlaOp all_f8_as_f8 = ConstantR1<tsl::float8_e4m3b11>(&builder, all_f8);
xla::XlaOp all_f8_as_f8 =
ConstantR1<tsl::float8_e4m3b11fnuz>(&builder, all_f8);
ConvertElementType(all_f8_as_f8, F32);
ComputeAndCompare(&builder, {}, ErrorSpec(0.));
}
Expand Down
2 changes: 1 addition & 1 deletion xla/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ std::string RoundTripFpToString(tsl::float8_e4m3fn value) {
return result;
}

std::string RoundTripFpToString(tsl::float8_e4m3b11 value) {
std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value) {
std::string result = GenericRoundTripFpToString(value);
return result;
}
Expand Down
2 changes: 1 addition & 1 deletion xla/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ std::string RoundTripFpToString(tsl::float8_e5m2 value);
std::string RoundTripFpToString(tsl::float8_e4m3fn value);

// Returns a string which can losslessly round trip to a float8 E4M3B11.
std::string RoundTripFpToString(tsl::float8_e4m3b11 value);
std::string RoundTripFpToString(tsl::float8_e4m3b11fnuz value);

// Returns a string which can losslessly round trip to a float8 E5M2FNUZ.
std::string RoundTripFpToString(tsl::float8_e5m2fnuz value);
Expand Down
12 changes: 7 additions & 5 deletions xla/util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ TEST(UtilTest, RoundTripFpToString) {
-std::numeric_limits<tsl::float8_e4m3fn>::quiet_NaN()),
"-nan");
EXPECT_EQ(RoundTripFpToString(
std::numeric_limits<tsl::float8_e4m3b11>::quiet_NaN()),
std::numeric_limits<tsl::float8_e4m3b11fnuz>::quiet_NaN()),
"-nan");
EXPECT_EQ(RoundTripFpToString(
std::numeric_limits<tsl::float8_e4m3fnuz>::quiet_NaN()),
Expand Down Expand Up @@ -249,11 +249,13 @@ TEST(UtilTest, TotalOrder_F8E4M3FN) {

TEST(UtilTest, TotalOrder_F8E4M3B11) {
for (int a = 0; a < 256; ++a) {
tsl::float8_e4m3b11 x =
Eigen::numext::bit_cast<tsl::float8_e4m3b11>(static_cast<uint8_t>(a));
tsl::float8_e4m3b11fnuz x =
Eigen::numext::bit_cast<tsl::float8_e4m3b11fnuz>(
static_cast<uint8_t>(a));
for (int b = 0; b < 256; ++b) {
tsl::float8_e4m3b11 y =
Eigen::numext::bit_cast<tsl::float8_e4m3b11>(static_cast<uint8_t>(b));
tsl::float8_e4m3b11fnuz y =
Eigen::numext::bit_cast<tsl::float8_e4m3b11fnuz>(
static_cast<uint8_t>(b));
TotalOrderHelper(x, y);
}
}
Expand Down

0 comments on commit 2db109a

Please sign in to comment.