PR #20744: [NVIDIA GPU] Add a flag to control a2a collective matmul rewrite

Imported from GitHub PR #20744

This addresses the revert in #19451, where customers saw an MFU regression when collective matmul was enabled by default. The a2a collective matmul kicked in by default on some small gemms and led to an inefficient transformation.
This change adds a flag that keeps the rewrite disabled by default, since it is still experimental.
Copybara import of the project:

--
f3d3208 by TJ Xu <[email protected]>:

Add a flag to control a2a collective matmul rewrite

--
0068abc by TJ Xu <[email protected]>:

added more comment for the new flag

--
9f88fe9 by TJ Xu <[email protected]>:

add flag to debug options

Merging this change closes #20744

COPYBARA_INTEGRATE_REVIEW=#20744 from Tixxx:tixxx/add_flag_a2a_gemm 9f88fe9
PiperOrigin-RevId: 713973994
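
For reference, a minimal sketch (not part of this commit) of how a module can opt in to the new rewrite programmatically, mirroring the DebugOptions setters used in the updated tests below; it assumes an existing std::unique_ptr<HloModule> named module:

// Hedged sketch: enable the experimental a2a windowed einsum rewrite on an
// HloModule. Per the new xla.proto comment, the rewrite is effective only
// when multi-streamed windowed einsum is also enabled.
DebugOptions& opts = module->mutable_config().mutable_debug_options();
opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
opts.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true);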
Tixxx authored and Google-ML-Automation committed Jan 10, 2025
1 parent b87c09a commit ba0410b
Showing 5 changed files with 52 additions and 5 deletions.
9 changes: 9 additions & 0 deletions xla/debug_options_flags.cc
@@ -2186,6 +2186,15 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_gpu_unsupported_enable_ragged_all_to_all_decomposer(),
"Internal: Enable the RaggedAllToAllDecomposer, an experimental pass "
"that rewrites ragged-all-to-all as a dense all-to-all operation."));
flag_list->push_back(tsl::Flag(
"xla_gpu_experimental_enable_alltoall_windowed_einsum",
bool_setter_for(
&DebugOptions::
set_xla_gpu_experimental_enable_alltoall_windowed_einsum),
debug_options->xla_gpu_experimental_enable_alltoall_windowed_einsum(),
"Enable windowed einsum rewrite for all-to-all+gemm pattern, "
"This optimization slices the all-to-all into smaller all-to-alls."
"It is an experimental feature."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
12 changes: 12 additions & 0 deletions xla/service/gpu/transforms/windowed_einsum_handler.cc
@@ -959,6 +959,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
// Rewrites an all-to-all+gemm into multiple independent partial a2a+gemms
// to minimize communication overhead. To do this, the original input is
// sliced into replica_group-sized pieces, and an all-to-all+gemm is
// performed on each piece.
if (!dot->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_enable_alltoall_windowed_einsum()) {
return absl::OkStatus();
}
HloInstruction* lhs;
HloInstruction* rhs;
std::vector<xla::ReplicaGroup> replica_groups;
@@ -1183,6 +1189,12 @@ class WindowedEinsumVisitor : public DfsHloRewriteVisitor {
absl::Status HandleAllToAll(HloInstruction* inst) override {
CHECK_EQ(inst->opcode(), HloOpcode::kAllToAll);
HloComputation* comp = inst->parent();
if (!inst->GetModule()
->config()
.debug_options()
.xla_gpu_experimental_enable_alltoall_windowed_einsum()) {
return absl::OkStatus();
}
// Rewrites a gemm+alltoall into multiple independent partial gemm+a2as
// to minimize communication overhead.
std::vector<xla::ReplicaGroup> replica_groups;
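The handler comments above describe slicing the all-to-all input into replica_group-sized pieces and performing a partial a2a+gemm on each. As a minimal standalone sketch (plain C++, not XLA code; collectives omitted, all names hypothetical), the following illustrates the algebraic fact such a decomposition relies on when the sliced dimension is the gemm's contracting dimension: summing the partial products over the slices reproduces the full product.

#include <cassert>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Accumulates C += A[:, k0:k0+kc] * B[k0:k0+kc, :], a partial gemm over one
// contiguous slice of the contracting dimension.
void PartialGemm(const Matrix& A, const Matrix& B, Matrix& C, int k0, int kc) {
  for (size_t m = 0; m < A.size(); ++m) {
    for (size_t n = 0; n < B[0].size(); ++n) {
      for (int k = k0; k < k0 + kc; ++k) {
        C[m][n] += A[m][k] * B[k][n];
      }
    }
  }
}

int main() {
  const int M = 2, K = 4, N = 3, kNumSlices = 2;
  Matrix A(M, std::vector<double>(K, 1.0));
  Matrix B(K, std::vector<double>(N, 2.0));
  Matrix full(M, std::vector<double>(N, 0.0));
  Matrix sliced(M, std::vector<double>(N, 0.0));

  // One large gemm over the whole contracting dimension.
  PartialGemm(A, B, full, /*k0=*/0, /*kc=*/K);

  // The same result from kNumSlices partial gemms; in the rewrite, each slice
  // would be fed by its own smaller all-to-all so that communication and
  // compute can overlap.
  for (int s = 0; s < kNumSlices; ++s) {
    PartialGemm(A, B, sliced, /*k0=*/s * (K / kNumSlices), /*kc=*/K / kNumSlices);
  }

  assert(full == sliced);
  return 0;
}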
12 changes: 12 additions & 0 deletions xla/service/gpu/transforms/windowed_einsum_handler_test.cc
@@ -387,6 +387,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,

WindowedEinsumHandler gpu_handler;
bool changed;
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true);
TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
RunFileCheck(module->ToString(), kExpected));
@@ -459,6 +462,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,8192]{3,2,1,0} add(bf16[1,4,2048,8192]{3,2,1,

WindowedEinsumHandler gpu_handler;
bool changed;
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true);
TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
RunFileCheck(module->ToString(), kExpected));
@@ -541,6 +547,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,2048,32768]{3,2,1,0} add(bf16[1,4,2048,32768]{3,2,

WindowedEinsumHandler gpu_handler;
bool changed;
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true);
TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
@@ -625,6 +634,9 @@ CHECK: ROOT {{.*}} = bf16[1,4,1,1,2048,8192]{5,4,3,2,1,0} reshape(bf16[1,4,1,204

WindowedEinsumHandler gpu_handler;
bool changed;
module->mutable_config()
.mutable_debug_options()
.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(true);
TF_ASSERT_OK_AND_ASSIGN(changed, gpu_handler.Run(module.get()));
EXPECT_TRUE(changed);
TF_ASSERT_OK_AND_ASSIGN(bool filecheck_matched,
17 changes: 13 additions & 4 deletions xla/tests/collective_ops_e2e_test.cc
@@ -811,7 +811,8 @@ TEST_F(CollectiveOpsTestE2E, NoAllToAllDecomposition) {
class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
public:
void CollectiveOpsCompareWindowedNonWindowed(
absl::string_view hlo_text, bool disable_dot_merger = false) {
absl::string_view hlo_text, bool disable_dot_merger = false,
bool enable_a2a_rewrite = false) {
const int64_t kNumReplicas = 1;
const int64_t kNumPartitions = 4;
if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
@@ -825,6 +826,8 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
auto opts = GetDebugOptionsForTest();
opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
opts.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(
enable_a2a_rewrite);
opts.set_xla_gpu_graph_min_graph_size(200);
opts.set_xla_gpu_enable_triton_gemm(false);
if (disable_dot_merger) {
@@ -1098,7 +1101,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
@@ -1114,7 +1119,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
@@ -1135,7 +1142,9 @@ ENTRY main.9_spmd {
}
)";

CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr);
CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
/*disable_dot_merger=*/false,
/*enable_a2a_rewrite=*/true);
}

TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) {
7 changes: 6 additions & 1 deletion xla/xla.proto
@@ -1101,7 +1101,12 @@ message DebugOptions {
// be deterministic, although with additional overhead.
bool xla_gpu_enable_scatter_determinism_expander = 345;

// Next id: 360
// Enable windowed einsum (collective matmul) rewrite for all-to-all + gemm.
// This feature is still experimental and is effective only when
// xla_gpu_multi_streamed_windowed_einsum is set to true.
bool xla_gpu_experimental_enable_alltoall_windowed_einsum = 360;

// Next id: 361

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.
