Skip to content

Commit

Permalink
Add support for int1 types in literal.cc
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 713875762
  • Loading branch information
amitsabne1 authored and Google-ML-Automation committed Jan 10, 2025
1 parent 066be4e commit ca7815f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
15 changes: 12 additions & 3 deletions xla/literal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ void ConvertEndianShort(char* bytes, int64_t size) {
}

bool LiteralProtoHasValues(const LiteralProto& proto) {
return !proto.s2s().empty() || !proto.s4s().empty() || !proto.s8s().empty() ||
!proto.s16s().empty() || proto.s32s_size() || proto.s64s_size() ||
!proto.u2s().empty() || !proto.u4s().empty() || !proto.u8s().empty() ||
return !proto.s1s().empty() || !proto.s2s().empty() || !proto.s4s().empty() ||
!proto.s8s().empty() || !proto.s16s().empty() || proto.s32s_size() ||
proto.s64s_size() || !proto.u1s().empty() || !proto.u2s().empty() ||
!proto.u4s().empty() || !proto.u8s().empty() ||
!proto.u16s().empty() || proto.u32s_size() || proto.u64s_size() ||
!proto.f8e5m2s().empty() || !proto.f8e4m3s().empty() ||
!proto.f8e4m3fns().empty() || !proto.f8e4m3b11fnuzs().empty() ||
Expand Down Expand Up @@ -2207,6 +2208,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case PRED:
CopyToRepeatedField(proto->mutable_preds(), data<bool>());
break;
case U1:
*proto->mutable_u1s() = std::string(
reinterpret_cast<const char*>(data<u1>().data()), size_bytes_dense());
break;
case U2:
*proto->mutable_u2s() = std::string(
reinterpret_cast<const char*>(data<u2>().data()), size_bytes_dense());
Expand All @@ -2233,6 +2238,10 @@ void LiteralBase::Piece::WriteToProto(LiteralProto* proto) const {
case U64:
CopyToRepeatedField(proto->mutable_u64s(), data<uint64_t>());
break;
case S1:
*proto->mutable_s1s() = std::string(
reinterpret_cast<const char*>(data<s1>().data()), size_bytes_dense());
break;
case S2:
*proto->mutable_s2s() = std::string(
reinterpret_cast<const char*>(data<s2>().data()), size_bytes_dense());
Expand Down
4 changes: 3 additions & 1 deletion xla/xla_data.proto
Original file line number Diff line number Diff line change
Expand Up @@ -562,9 +562,11 @@ message DeviceAssignmentProto {
message LiteralProto {
ShapeProto shape = 1;
repeated bool preds = 2;
bytes s1s = 30;
bytes s2s = 26;
bytes s4s = 21;
bytes s8s = 15;
bytes u1s = 31;
bytes u2s = 27;
bytes u4s = 22;
bytes u8s = 3;
Expand All @@ -590,7 +592,7 @@ message LiteralProto {
bytes f8e4m3fnuzs = 25;
bytes f8e3m4s = 29;
repeated int64 sparse_indices = 14;
// Next = 30
// Next = 32
}

message WindowDimension {
Expand Down

0 comments on commit ca7815f

Please sign in to comment.