From e6a6acf0b08460e13e330cf800710e086c7c7d15 Mon Sep 17 00:00:00 2001 From: Greg Olechwierowicz Date: Fri, 10 Jan 2025 01:25:19 -0800 Subject: [PATCH] [XLA:GPU] Fix reduce scatter transferred bytes. PiperOrigin-RevId: 713955797 --- .../gpu/model/gpu_hlo_cost_analysis.cc | 24 ++++++++++++-- xla/service/gpu/model/gpu_hlo_cost_analysis.h | 1 + .../gpu/model/gpu_hlo_cost_analysis_test.cc | 27 +++++++++++++++- .../gpu/model/sol_latency_estimator_test.cc | 31 +++++++++++++++++++ 4 files changed, 79 insertions(+), 4 deletions(-) diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis.cc b/xla/service/gpu/model/gpu_hlo_cost_analysis.cc index 6462106f95a7b..0461814d6d7c6 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis.cc +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis.cc @@ -503,12 +503,13 @@ absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) { VLOG(2) << "Only Reduce Scatter is supported."; return absl::OkStatus(); } + int index_to_skip = 1; int64_t output_bytes_accessed = 0; ShapeUtil::ForEachLeafShape( hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) { - // Skip first element of a tuple as it expresses the input of the - // collective operation. - if (index.empty() || index.front() == 0) { + // Skip second element of a tuple as it is an output but it is not + // actual bytes transferred. 
+ if (index.empty() || index.front() == index_to_skip) { return; } if (subshape.IsArray()) { @@ -520,6 +521,23 @@ absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) { return absl::OkStatus(); } +absl::Status GpuHloCostAnalysis::HandleReduceScatter( + const HloInstruction* hlo) { + int64_t output_bytes_accessed = 0; + + for (auto* operand : hlo->operands()) { + ShapeUtil::ForEachLeafShape( + operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) { + if (subshape.IsArray()) { + output_bytes_accessed += GetShapeSize(subshape); + } + }); + } + current_properties_.set_output_bytes_accessed(output_bytes_accessed); + + return absl::OkStatus(); +} + absl::Status GpuHloCostAnalysis::HandleElementwiseOp( const HloInstruction* hlo) { current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(hlo); diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis.h b/xla/service/gpu/model/gpu_hlo_cost_analysis.h index 64cb9db1d1a70..5561a321b318e 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis.h +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis.h @@ -76,6 +76,7 @@ class GpuHloCostAnalysis : public HloCostAnalysis { absl::Status HandleAllGather(const HloInstruction* hlo) override; absl::Status HandleAllGatherStart(const HloInstruction* hlo) override; absl::Status HandleAsyncStart(const HloInstruction* hlo) override; + absl::Status HandleReduceScatter(const HloInstruction* hlo) override; // Estimate the total size of IR accounting for both duplication // of producer code by consumer and the total number of basic blocks. 
diff --git a/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc b/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc index 3cabadfd6aab6..71b7da2332e30 100644 --- a/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc +++ b/xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc @@ -709,6 +709,31 @@ ENTRY entry_computation { EXPECT_EQ(analysis_.output_bytes_accessed(*all_gather), 4096 * 4 + 2048 * 4); } +TEST_F(GpuHloCostAnalysisTest, ReduceScatter) { + absl::string_view hlo_string = R"( +HloModule m + +add { + param_0 = f32[] parameter(0) + param_1 = f32[] parameter(1) + ROOT t = f32[] add(param_0, param_1) +} + +ENTRY entry_computation { + p = f32[4096] parameter(0) + ROOT _ = f32[1024] reduce-scatter(p), dimensions={0}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + + ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_)); + + const HloInstruction* reduce_scatter = + module->entry_computation()->root_instruction(); + EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter), 4096 * 4); +} + TEST_F(GpuHloCostAnalysisTest, AsyncReduceScatter) { absl::string_view hlo_string = R"( HloModule m @@ -743,7 +768,7 @@ ENTRY entry_computation { module->entry_computation()->root_instruction()->operand(0); // Output is (f32[1024],f32[512]). 
EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter), - 1024 * 4 + 512 * 4); + 4096 * 4 + 2048 * 4); } TEST_F(GpuHloCostAnalysisTest, CustomOpProfileIsUsed) { diff --git a/xla/service/gpu/model/sol_latency_estimator_test.cc b/xla/service/gpu/model/sol_latency_estimator_test.cc index 7399030d895de..de40364d29f88 100644 --- a/xla/service/gpu/model/sol_latency_estimator_test.cc +++ b/xla/service/gpu/model/sol_latency_estimator_test.cc @@ -137,10 +137,41 @@ ENTRY main { /*expected_latency=*/absl::Microseconds(1323), }; + EstimatorTestCase reduce_scatter_all_ranks = { + /*test_name=*/"reduce_scatter_all_ranks", + /*module_string=*/R"( +HloModule m + +add { + param_0 = bf16[] parameter(0) + param_1 = bf16[] parameter(1) + ROOT t = bf16[] add(param_0, param_1) +} + +async_comp { + param_3 = bf16[8192,128256] parameter(0) + ROOT r = bf16[64,128256] reduce-scatter(param_3), + dimensions={0}, + to_apply=add, + replica_groups=[1,128]<=[128], + channel_id=1, + use_global_device_ids=true +} + +ENTRY main { + p = bf16[8192,128256] parameter(0) + rs-start = ((bf16[8192,128256]), bf16[64,128256]) async-start(p), calls=async_comp + ROOT rs-done = bf16[64,128256] async-done(rs-start) +})", + /*opcode=*/HloOpcode::kAsyncStart, + /*expected_latency=*/absl::Microseconds(10525), + }; + return { all_gather_intra_host, all_gather_inter_host_pairwise, all_gather_all_ranks, + reduce_scatter_all_ranks, }; }