From ca7815f1ac58b582a2d5f40bee273d8e82588e69 Mon Sep 17 00:00:00 2001 From: Amit Sabne Date: Thu, 9 Jan 2025 19:07:36 -0800 Subject: [PATCH] Add support for int1 types in literal.cc PiperOrigin-RevId: 713875762 --- xla/literal.cc | 15 ++++++++++++--- xla/xla_data.proto | 4 +++- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/xla/literal.cc b/xla/literal.cc index 997f44a4dd0f6..6b5db7f893ec4 100644 --- a/xla/literal.cc +++ b/xla/literal.cc @@ -87,9 +87,10 @@ void ConvertEndianShort(char* bytes, int64_t size) { } bool LiteralProtoHasValues(const LiteralProto& proto) { - return !proto.s2s().empty() || !proto.s4s().empty() || !proto.s8s().empty() || - !proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() || - !proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() || + return !proto.s1s().empty() || !proto.s2s().empty() || !proto.s4s().empty() || + !proto.s8s().empty() || !proto.s16s().empty() || proto.s32s_size() || + proto.s64s_size() || !proto.u1s().empty() || !proto.u2s().empty() || + !proto.u4s().empty() || !proto.u8s().empty() || !proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() || !proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() || !proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() || @@ -2207,6 +2208,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case PRED: CopyToRepeatedField(proto->mutable_preds(), data()); break; + case U1: + *proto->mutable_u1s() = std::string( + reinterpret_cast(data().data()), size_bytes_dense()); + break; case U2: *proto->mutable_u2s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); @@ -2233,6 +2238,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const { case U64: CopyToRepeatedField(proto->mutable_u64s(), data()); break; + case S1: + *proto->mutable_s1s() = std::string( + reinterpret_cast(data().data()), size_bytes_dense()); + break; case S2: *proto->mutable_s2s() = std::string( reinterpret_cast(data().data()), size_bytes_dense()); diff --git a/xla/xla_data.proto b/xla/xla_data.proto index 3bdf7c6c8cba3..01a6415549b58 100644 --- a/xla/xla_data.proto +++ b/xla/xla_data.proto @@ -562,9 +562,11 @@ message DeviceAssignmentProto { message LiteralProto { ShapeProto shape = 1; repeated bool preds = 2; + bytes s1s = 30; bytes s2s = 26; bytes s4s = 21; bytes s8s = 15; + bytes u1s = 31; bytes u2s = 27; bytes u4s = 22; bytes u8s = 3; @@ -590,7 +592,7 @@ message LiteralProto { bytes f8e4m3fnuzs = 25; bytes f8e3m4s = 29; repeated int64 sparse_indices = 14; - // Next = 30 + // Next = 32 } message WindowDimension {