Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PR #21813: Support e8m0fnu for NCCL collectives #21836

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/gpu/collectives/nccl_communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ static absl::StatusOr<ncclDataType_t> ToNcclDataType(PrimitiveType dtype,
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
case F8E8M0FNU:
return ncclInt8;
case PRED:
case U8:
Expand Down
1 change: 1 addition & 0 deletions xla/backends/gpu/runtime/nccl_collective_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type,
case F8E4M3FN:
case F8E5M2FNUZ:
case F8E4M3FNUZ:
case F8E8M0FNU:
return !IsReductionCollective(reduction_op);
default:
return false;
Expand Down
43 changes: 27 additions & 16 deletions xla/tests/collective_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2629,30 +2629,41 @@ class Fp8CollectiveOpsTest : public CollectiveOpsTest {
};

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) {
const char* const kModuleStr = R"(
const char* const kModuleTemplate = R"(
HloModule test
ENTRY test_computation {
a0 = <<F8E4M3>>[1,2] constant({{1,2}})
allgather = <<F8E4M3>>[2, 2] all-gather(a0), dimensions={0}
p = <<F8E4M3>>[4] reshape(allgather)
a0 = <<TYPE>>[1,2] constant({{1,2}})
allgather = <<TYPE>>[2, 2] all-gather(a0), dimensions={0}
p = <<TYPE>>[4] reshape(allgather)
ROOT out = f32[4] convert(p)
}
)";
const int64_t kNumReplicas = 2;
HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas);
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(kModuleStr, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas,
/*use_threads=*/true, /*run_hlo_passes=*/true));
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
auto runTestForType = [&](const std::string& type) {
std::string hlo_str =
absl::StrReplaceAll(kModuleTemplate, {{"TYPE", type}});

// Parse the HLO module and execute it
TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(
absl::StrReplaceAll(hlo_str, replacements_), config));
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), absl::Span<Literal* const>{},
kNumReplicas, /*use_threads=*/true,
/*run_hlo_passes=*/true));

// Verify the results
ASSERT_EQ(results.size(), kNumReplicas);
for (const Literal& result : results) {
LiteralTestUtil::ExpectR1Equal<float>({1, 2, 1, 2}, result);
}
};
runTestForType("F8E8M0");
runTestForType("F8E4M3");
runTestForType("F8E5M2");
}

XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {
Expand Down