From 4bc7921ec465c0a7a488e468c171c1615a148a53 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Fri, 24 Jan 2025 17:46:23 -0800 Subject: [PATCH] PR #21813: Support e8m0fnu for NCCL collectives Imported from GitHub PR https://github.com/openxla/xla/pull/21813 Support e8m0fnu date type for NCCL collectives. Copybara import of the project: -- 37c5d5baf563b78c7ed0f343b6f1d74c1d9c271b by wenscarl : Support e8m0fnu for NCCL collectives -- cd4f37e1019f053dbe6039953a462de637289ddc by Shu Wang : Add missing data placeholder. Merging this change closes #21813 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/openxla/xla/pull/21813 from wenscarl:e8m0_nccl cd4f37e1019f053dbe6039953a462de637289ddc PiperOrigin-RevId: 719490916 --- .../gpu/collectives/nccl_communicator.cc | 1 + .../gpu/runtime/nccl_collective_thunk.cc | 1 + xla/tests/collective_ops_test.cc | 45 ++++++++++++------- 3 files changed, 31 insertions(+), 16 deletions(-) diff --git a/xla/backends/gpu/collectives/nccl_communicator.cc b/xla/backends/gpu/collectives/nccl_communicator.cc index 17f92e9575d54..32dc200a17d29 100644 --- a/xla/backends/gpu/collectives/nccl_communicator.cc +++ b/xla/backends/gpu/collectives/nccl_communicator.cc @@ -70,6 +70,7 @@ static absl::StatusOr ToNcclDataType(PrimitiveType dtype, case F8E4M3FN: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E8M0FNU: return ncclInt8; case PRED: case U8: diff --git a/xla/backends/gpu/runtime/nccl_collective_thunk.cc b/xla/backends/gpu/runtime/nccl_collective_thunk.cc index cefe173dfc67e..addfec329e223 100644 --- a/xla/backends/gpu/runtime/nccl_collective_thunk.cc +++ b/xla/backends/gpu/runtime/nccl_collective_thunk.cc @@ -90,6 +90,7 @@ bool IsTypeSupportedByNccl(PrimitiveType element_type, case F8E4M3FN: case F8E5M2FNUZ: case F8E4M3FNUZ: + case F8E8M0FNU: return !IsReductionCollective(reduction_op); default: return false; diff --git a/xla/tests/collective_ops_test.cc b/xla/tests/collective_ops_test.cc index 8c872afe9d669..4b8e29dd59514 100644 --- a/xla/tests/collective_ops_test.cc +++ b/xla/tests/collective_ops_test.cc @@ -2607,6 +2607,7 @@ class Fp8CollectiveOpsTest : public CollectiveOpsTest { IsCuda() ? "f8e4m3fn" : "f8e4m3fnuz"; replacements_[kF8E5M2DatatypePlaceholder] = IsCuda() ? "f8e5m2" : "f8e5m2fnuz"; + replacements_[kF8E8M0DatatypePlaceholder] = "f8e8m0fnu"; } protected: @@ -2626,33 +2627,45 @@ class Fp8CollectiveOpsTest : public CollectiveOpsTest { private: static constexpr const char* kF8E4M3DatatypePlaceholder{"<>"}; static constexpr const char* kF8E5M2DatatypePlaceholder{"<>"}; + static constexpr const char* kF8E8M0DatatypePlaceholder{"<>"}; }; XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllGather_8BitFloat)) { - const char* const kModuleStr = R"( + const char* const kModuleTemplate = R"( HloModule test ENTRY test_computation { - a0 = <>[1,2] constant({{1,2}}) - allgather = <>[2, 2] all-gather(a0), dimensions={0} - p = <>[4] reshape(allgather) + a0 = <>[1,2] constant({{1,2}}) + allgather = <>[2, 2] all-gather(a0), dimensions={0} + p = <>[4] reshape(allgather) ROOT out = f32[4] convert(p) } )"; const int64_t kNumReplicas = 2; HloModuleConfig config = GetModuleConfigForTest(/*replica_count=*/kNumReplicas); - TF_ASSERT_OK_AND_ASSIGN( - auto module, ParseAndReturnVerifiedModule( - absl::StrReplaceAll(kModuleStr, replacements_), config)); - TF_ASSERT_OK_AND_ASSIGN( - std::vector results, - ExecuteReplicated(std::move(module), absl::Span{}, - kNumReplicas, - /*use_threads=*/true, /*run_hlo_passes=*/true)); - ASSERT_EQ(results.size(), kNumReplicas); - for (const Literal& result : results) { - LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); - } + auto runTestForType = [&](const std::string& type) { + std::string hlo_str = + absl::StrReplaceAll(kModuleTemplate, {{"TYPE", type}}); + + // Parse the HLO module and execute it + TF_ASSERT_OK_AND_ASSIGN( + auto module, ParseAndReturnVerifiedModule( + absl::StrReplaceAll(hlo_str, replacements_), config)); + TF_ASSERT_OK_AND_ASSIGN( + std::vector results, + ExecuteReplicated(std::move(module), absl::Span{}, + kNumReplicas, /*use_threads=*/true, + /*run_hlo_passes=*/true)); + + // Verify the results + ASSERT_EQ(results.size(), kNumReplicas); + for (const Literal& result : results) { + LiteralTestUtil::ExpectR1Equal({1, 2, 1, 2}, result); + } + }; + runTestForType("F8E8M0"); + runTestForType("F8E4M3"); + runTestForType("F8E5M2"); } XLA_TEST_F(Fp8CollectiveOpsTest, DISABLED_ON_CPU(AllToAll_8BitFloat)) {