[XLA:GPU] Fix reduce scatter transferred bytes.
PiperOrigin-RevId: 713615261
golechwierowicz authored and Google-ML-Automation committed Jan 10, 2025
1 parent c99a4f9 commit 4950971
Showing 4 changed files with 79 additions and 4 deletions.
24 changes: 21 additions & 3 deletions xla/service/gpu/model/gpu_hlo_cost_analysis.cc
@@ -503,12 +503,13 @@ absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) {
     VLOG(2) << "Only Reduce Scatter is supported.";
     return absl::OkStatus();
   }
+  int index_to_skip = 1;
   int64_t output_bytes_accessed = 0;
   ShapeUtil::ForEachLeafShape(
       hlo->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
-        // Skip first element of a tuple as it expresses the input of the
-        // collective operation.
-        if (index.empty() || index.front() == 0) {
+        // Skip the second element of the tuple: it is an output, but it
+        // does not represent actual bytes transferred.
+        if (index.empty() || index.front() == index_to_skip) {
           return;
         }
         if (subshape.IsArray()) {
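For reference, the shape of a reduce-scatter async-start is a tuple whose element 0 holds the operands and whose element 1 holds the shard-sized output (see the rs-start shape in the latency-estimator test below). Every participant sends its full input, so the bytes transferred are the input bytes; the hunk above therefore skips the leaves under tuple index 1, where the old code skipped index 0 and counted the output instead. A minimal standalone sketch of this accounting, using the shapes from the AsyncReduceScatter test below and a 4-byte f32; it models ShapeUtil::ForEachLeafShape with a plain list of leaves rather than the real XLA API:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Async reduce-scatter shape ((f32[4096], f32[2048]), (f32[1024], f32[512])),
  // flattened into {top-level tuple index, element count} leaves.
  struct Leaf { int top_index; int64_t elements; };
  std::vector<Leaf> leaves = {
      {0, 4096}, {0, 2048},  // tuple index 0: the operands
      {1, 1024}, {1, 512},   // tuple index 1: the output shards
  };
  const int index_to_skip = 1;  // skip the output, as in the fixed handler
  int64_t output_bytes_accessed = 0;
  for (const Leaf& leaf : leaves) {
    if (leaf.top_index == index_to_skip) continue;
    output_bytes_accessed += leaf.elements * 4;  // 4 bytes per f32 element
  }
  std::cout << output_bytes_accessed << "\n";  // 24576 == 4096*4 + 2048*4
  return 0;
}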
@@ -520,6 +521,23 @@ absl::Status GpuHloCostAnalysis::HandleAsyncStart(const HloInstruction* hlo) {
   return absl::OkStatus();
 }
 
+absl::Status GpuHloCostAnalysis::HandleReduceScatter(
+    const HloInstruction* hlo) {
+  int64_t output_bytes_accessed = 0;
+
+  for (auto* operand : hlo->operands()) {
+    ShapeUtil::ForEachLeafShape(
+        operand->shape(), [&](const Shape& subshape, const ShapeIndex& index) {
+          if (subshape.IsArray()) {
+            output_bytes_accessed += GetShapeSize(subshape);
+          }
+        });
+  }
+  current_properties_.set_output_bytes_accessed(output_bytes_accessed);
+
+  return absl::OkStatus();
+}
+
 absl::Status GpuHloCostAnalysis::HandleElementwiseOp(
     const HloInstruction* hlo) {
   current_properties_[kFlopsKey] = GetFlopsForElementwiseOp(hlo);
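The new synchronous handler sums the operand sizes the same way, and the updated tests below pin down the numbers: the sync case expects 4096 * 4 = 16384 bytes for an f32[4096] input rather than the f32[1024] shard, and the async case changes from 1024 * 4 + 512 * 4 = 6144 to 4096 * 4 + 2048 * 4 = 24576. A quick check of that arithmetic (values copied from the tests; plain C++, not the cost-analysis API):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t kF32Bytes = 4;
  // Sync test: f32[4096] reduce-scattered to f32[1024] on 4 participants.
  int64_t sync_expected = 4096 * kF32Bytes;       // 16384
  // Async test: operands (f32[4096], f32[2048]); the old expectation
  // counted the output shards (f32[1024], f32[512]) instead.
  int64_t async_new = (4096 + 2048) * kF32Bytes;  // 24576
  int64_t async_old = (1024 + 512) * kF32Bytes;   //  6144
  std::cout << sync_expected << " " << async_new << " " << async_old << "\n";
  return 0;
}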
1 change: 1 addition & 0 deletions xla/service/gpu/model/gpu_hlo_cost_analysis.h
@@ -76,6 +76,7 @@ class GpuHloCostAnalysis : public HloCostAnalysis {
   absl::Status HandleAllGather(const HloInstruction* hlo) override;
   absl::Status HandleAllGatherStart(const HloInstruction* hlo) override;
   absl::Status HandleAsyncStart(const HloInstruction* hlo) override;
+  absl::Status HandleReduceScatter(const HloInstruction* hlo) override;
 
   // Estimate the total size of IR accounting for both duplication
   // of producer code by consumer and the total number of basic blocks.
27 changes: 26 additions & 1 deletion xla/service/gpu/model/gpu_hlo_cost_analysis_test.cc
@@ -709,6 +709,31 @@ ENTRY entry_computation {
   EXPECT_EQ(analysis_.output_bytes_accessed(*all_gather), 4096 * 4 + 2048 * 4);
 }
 
+TEST_F(GpuHloCostAnalysisTest, ReduceScatter) {
+  absl::string_view hlo_string = R"(
+HloModule m
+
+add {
+  param_0 = f32[] parameter(0)
+  param_1 = f32[] parameter(1)
+  ROOT t = f32[] add(param_0, param_1)
+}
+
+ENTRY entry_computation {
+  p = f32[4096] parameter(0)
+  ROOT _ = f32[1024] reduce-scatter(p), dimensions={0}, to_apply=add
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  ASSERT_IS_OK(module->entry_computation()->Accept(&analysis_));
+
+  const HloInstruction* reduce_scatter =
+      module->entry_computation()->root_instruction();
+  EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter), 4096 * 4);
+}
+
 TEST_F(GpuHloCostAnalysisTest, AsyncReduceScatter) {
   absl::string_view hlo_string = R"(
 HloModule m
@@ -743,7 +768,7 @@ ENTRY entry_computation {
       module->entry_computation()->root_instruction()->operand(0);
   // Output is (f32[1024],f32[512]).
   EXPECT_EQ(analysis_.output_bytes_accessed(*reduce_scatter),
-            1024 * 4 + 512 * 4);
+            4096 * 4 + 2048 * 4);
 }
 
 TEST_F(GpuHloCostAnalysisTest, CustomOpProfileIsUsed) {
31 changes: 31 additions & 0 deletions xla/service/gpu/model/sol_latency_estimator_test.cc
@@ -137,10 +137,41 @@ ENTRY main {
     /*expected_latency=*/absl::Microseconds(1323),
   };
 
+  EstimatorTestCase reduce_scatter_all_ranks = {
+      /*test_name=*/"reduce_scatter_all_ranks",
+      /*module_string=*/R"(
+HloModule m
+
+add {
+  param_0 = bf16[] parameter(0)
+  param_1 = bf16[] parameter(1)
+  ROOT t = bf16[] add(param_0, param_1)
+}
+
+async_comp {
+  param_3 = bf16[8192,128256] parameter(0)
+  ROOT r = bf16[64,128256] reduce-scatter(param_3),
+    dimensions={0},
+    to_apply=add,
+    replica_groups=[1,128]<=[128],
+    channel_id=1,
+    use_global_device_ids=true
+}
+
+ENTRY main {
+  p = bf16[8192,128256] parameter(0)
+  rs-start = ((bf16[8192,128256]), bf16[64,128256]) async-start(p), calls=async_comp
+  ROOT rs-done = bf16[64,128256] async-done(rs-start)
+})",
+      /*opcode=*/HloOpcode::kAsyncStart,
+      /*expected_latency=*/absl::Microseconds(10525),
+  };
+
   return {
       all_gather_intra_host,
       all_gather_inter_host_pairwise,
       all_gather_all_ranks,
+      reduce_scatter_all_ranks,
   };
 }
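The new estimator case exercises reduce-scatter across all 128 ranks (replica_groups=[1,128]<=[128]). Under the fixed accounting, the cost model sees the full bf16[8192,128256] input rather than the bf16[64,128256] shard. The sizes involved, as plain arithmetic (this checks tensor sizes only, not the SoL latency model itself):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t kBf16Bytes = 2;
  // bf16[8192,128256] input, reduce-scattered over 128 ranks to bf16[64,128256].
  int64_t input_bytes = int64_t{8192} * 128256 * kBf16Bytes;
  int64_t shard_bytes = int64_t{64} * 128256 * kBf16Bytes;
  std::cout << input_bytes << "\n";                // 2101346304 (~2.1 GB)
  std::cout << shard_bytes << "\n";                // 16416768 (~16.4 MB)
  std::cout << input_bytes / shard_bytes << "\n";  // 128 == number of ranks
  return 0;
}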
