[XLA] Copy metadata to combined collectives.
This is a naive approach that copies the metadata of the first of the combined collective operations onto the combined op. We can explore a more complex solution if we encounter use cases where this is not correct.

PiperOrigin-RevId: 719321555
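
As a sketch of the intended behavior (illustrative HLO with made-up op names, not taken from the change itself): combining two all-gathers now produces a tuple-shaped op that carries the metadata of the first combined op, where previously the combined op was created with empty metadata.

    // Before combining (hypothetical input):
    ag0 = f32[128] all-gather(p0), dimensions={0}, metadata={op_name="layer0/all-gather"}
    ag1 = f32[128] all-gather(p1), dimensions={0}, metadata={op_name="layer1/all-gather"}

    // After combining: metadata is copied from ag0, the first combined op.
    combined = (f32[128], f32[128]) all-gather(p0, p1), dimensions={0}, metadata={op_name="layer0/all-gather"}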
allanrenucci authored and Google-ML-Automation committed Jan 24, 2025
1 parent fcca860 commit 894a70e
Showing 8 changed files with 125 additions and 22 deletions.
2 changes: 2 additions & 0 deletions xla/hlo/transforms/collectives/BUILD
@@ -265,6 +265,7 @@ xla_cc_test(
         "//xla/hlo/ir:hlo",
         "//xla/hlo/testlib:hlo_hardware_independent_test_base",
         "//xla/hlo/utils:hlo_matchers",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
         "@tsl//tsl/platform:statusor",
     ],
@@ -310,6 +311,7 @@ xla_cc_test(
         "//xla/hlo/utils:hlo_matchers",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
         "@tsl//tsl/platform:statusor",
     ],
15 changes: 8 additions & 7 deletions xla/hlo/transforms/collectives/all_gather_combiner.cc
@@ -132,13 +132,14 @@ absl::Status CombineAllGathers(absl::Span<HloInstruction* const> to_combine,
   }
 
   // Create combined all-gather op with a tuple result.
-  HloInstruction* combined;
-  combined = computation.AddInstruction(HloInstruction::CreateAllGather(
-      ShapeUtil::MakeTupleShape(output_shapes), operands, most_frequent_dim,
-      to_combine.front()->device_list(),
-      /*constrain_layout=*/false, to_combine.front()->channel_id(),
-      Cast<HloAllGatherInstruction>(to_combine.front())
-          ->use_global_device_ids()));
+  HloInstruction* combined =
+      computation.AddInstruction(HloInstruction::CreateAllGather(
+          ShapeUtil::MakeTupleShape(output_shapes), operands, most_frequent_dim,
+          to_combine.front()->device_list(),
+          /*constrain_layout=*/false, to_combine.front()->channel_id(),
+          Cast<HloAllGatherInstruction>(to_combine.front())
+              ->use_global_device_ids()));
+  combined->set_metadata(to_combine.front()->metadata());
 
   // We have to propagate the sharding manually because Domain instructions are
   // not guaranteed to preserve it for side effecting instructions.
31 changes: 31 additions & 0 deletions xla/hlo/transforms/collectives/all_gather_combiner_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_casting_utils.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -616,5 +617,35 @@ ENTRY entry {
   ASSERT_EQ(1, all_gathers[1]->all_gather_dimension());
 }
 
+TEST_F(AllGatherCombinerTest, PreservesMetadata) {
+  absl::string_view hlo_string = R"(
+HloModule Module
+
+ENTRY entry {
+  param0 = f32[32] parameter(0)
+  param1 = f32[32] parameter(1)
+  allgather0 = f32[128] all-gather(param0), replica_groups={}, dimensions={0}, metadata={op_type="test_type0" op_name="test_name0"}
+  allgather1 = f32[128] all-gather(param1), replica_groups={}, dimensions={0}, metadata={op_type="test_type1" op_name="test_name1"}
+  ROOT tuple = (f32[128], f32[128]) tuple(allgather0, allgather1)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+
+  AllGatherCombiner combine(1024 * 1024, kMaxCombineCount,
+                            /*combine_by_dim=*/true);
+  ASSERT_EQ(AllGatherCount(*module), 2);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
+  EXPECT_TRUE(changed);
+
+  OpMetadata metadata;
+  metadata.set_op_type("test_type0");
+  metadata.set_op_name("test_name0");
+  Matcher<const HloInstruction*> combined_all_gather = op::Metadata(metadata);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::GetTupleElement(combined_all_gather, 0),
+                        op::GetTupleElement(combined_all_gather, 1)));
+}
+
 }  // namespace
 }  // namespace xla
15 changes: 8 additions & 7 deletions xla/hlo/transforms/collectives/all_reduce_combiner.cc
@@ -81,15 +81,16 @@ absl::Status CombineAllReduces(absl::Span<HloInstruction* const> to_combine) {
     }
   }
 
-  HloInstruction* combined;
   // AllReduce ops with more than one operand produce a tuple.
   TF_RET_CHECK(operands.size() >= 2);
-  combined = computation.AddInstruction(HloInstruction::CreateAllReduce(
-      ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes), operands, reduction,
-      to_combine.front()->device_list(),
-      /*constrain_layout=*/false, to_combine.front()->channel_id(),
-      Cast<HloAllReduceInstruction>(to_combine.front())
-          ->use_global_device_ids()));
+  HloInstruction* combined =
+      computation.AddInstruction(HloInstruction::CreateAllReduce(
+          ShapeUtil::MakeTupleShapeWithPtrs(operand_shapes), operands,
+          reduction, to_combine.front()->device_list(),
+          /*constrain_layout=*/false, to_combine.front()->channel_id(),
+          Cast<HloAllReduceInstruction>(to_combine.front())
+              ->use_global_device_ids()));
+  combined->set_metadata(to_combine.front()->metadata());
 
   // We have to propagate the sharding manually because Domain instructions are
   // not guaranteed to preserve it for side effecting instructions.
33 changes: 33 additions & 0 deletions xla/hlo/transforms/collectives/all_reduce_combiner_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
 #include <gtest/gtest.h>
 #include "absl/log/check.h"
 #include "absl/log/log.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/collective_device_list.h"
 #include "xla/hlo/ir/hlo_computation.h"
 #include "xla/hlo/ir/hlo_instruction.h"
@@ -472,5 +473,37 @@ ENTRY %comp {
       op::Tuple(op::GetTupleElement(crs1, 0), op::GetTupleElement(crs1, 1)));
 }
 
+TEST_F(AllReduceCombinerTest, PreservesMetadata) {
+  absl::string_view hlo_text = R"(
+HloModule Module
+
+%add (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY entry {
+  %param.0 = f32[32] parameter(0)
+  %param.1 = f32[32] parameter(1)
+  %all-reduce.0 = f32[32] all-reduce(%param.0), replica_groups={}, to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
+  %all-reduce.1 = f32[32] all-reduce(%param.1), replica_groups={}, to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
+  ROOT tuple = (f32[32], f32[32]) tuple(%all-reduce.0, %all-reduce.1)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(hlo_text));
+  AllReduceCombiner combine(1024 * 1024, kMaxCombineCount);
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, combine.Run(module.get()));
+  EXPECT_TRUE(changed);
+  OpMetadata metadata;
+  metadata.set_op_type("test_type0");
+  metadata.set_op_name("test_name0");
+  auto combined_all_reduce = op::Metadata(metadata);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::GetTupleElement(combined_all_reduce, 0),
+                        op::GetTupleElement(combined_all_reduce, 1)));
+}
+
 }  // namespace
 }  // namespace xla
1 change: 1 addition & 0 deletions xla/service/BUILD
@@ -2494,6 +2494,7 @@ xla_cc_test(
         "//xla/tests:hlo_test_base",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings:string_view",
         "@com_google_googletest//:gtest_main",
     ],
 )
17 changes: 9 additions & 8 deletions xla/service/reduce_scatter_combiner.cc
@@ -138,15 +138,16 @@ absl::Status CombineReduceScatters(
   }
 
   // Create combined scatter-reduce op with a tuple result.
-  HloInstruction* combined;
   TF_RET_CHECK(operands.size() >= 2);
-  combined = computation.AddInstruction(HloInstruction::CreateReduceScatter(
-      ShapeUtil::MakeTupleShape(output_shapes), operands, reduction,
-      to_combine.front()->device_list(),
-      /*constrain_layout=*/false, to_combine.front()->channel_id(),
-      Cast<HloReduceScatterInstruction>(to_combine.front())
-          ->use_global_device_ids(),
-      most_frequent_dim));
+  HloInstruction* combined =
+      computation.AddInstruction(HloInstruction::CreateReduceScatter(
+          ShapeUtil::MakeTupleShape(output_shapes), operands, reduction,
+          to_combine.front()->device_list(),
+          /*constrain_layout=*/false, to_combine.front()->channel_id(),
+          Cast<HloReduceScatterInstruction>(to_combine.front())
+              ->use_global_device_ids(),
+          most_frequent_dim));
+  combined->set_metadata(to_combine.front()->metadata());
 
   // We have to propagate the sharding manually because Domain instructions are
   // not guaranteed to preserve it for side effecting instructions.
33 changes: 33 additions & 0 deletions xla/service/reduce_scatter_combiner_test.cc
@@ -18,8 +18,10 @@ limitations under the License.
 #include <cstddef>
 #include <utility>
 
+#include <gmock/gmock.h>
 #include "absl/algorithm/container.h"
 #include "absl/log/log.h"
+#include "absl/strings/string_view.h"
 #include "xla/hlo/ir/hlo_instruction.h"
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/hlo/ir/hlo_opcode.h"
@@ -29,6 +31,8 @@ limitations under the License.
 namespace xla {
 namespace {
 
+namespace op = xla::testing::opcode_matchers;
+
 constexpr int64_t kMaxCombineCount = 256;
 constexpr int64_t kMaxByteCount = 10 * 1024 * 1024;
 
@@ -349,5 +353,34 @@ ENTRY main {
   EXPECT_FALSE(changed);
 }
 
+TEST_F(ReduceScatterCombinerTest, PreservesMetadata) {
+  absl::string_view hlo_string = R"(
+HloModule Module
+
+%add (x: f32[], y: f32[]) -> f32[] {
+  %x = f32[] parameter(0)
+  %y = f32[] parameter(1)
+  ROOT %add = f32[] add(f32[] %x, f32[] %y)
+}
+
+ENTRY entry {
+  %param.0 = f32[32] parameter(0)
+  %param.1 = f32[32] parameter(1)
+  %rs.0 = f32[16] reduce-scatter(%param.0), replica_groups={{0,1}}, dimensions={0}, to_apply=%add, metadata={op_type="test_type0" op_name="test_name0"}
+  %rs.1 = f32[16] reduce-scatter(%param.1), replica_groups={{0,1}}, dimensions={0}, to_apply=%add, metadata={op_type="test_type1" op_name="test_name1"}
+  ROOT tuple = (f32[16], f32[16]) tuple(%rs.0, %rs.1)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          RunPass(hlo_string, /*expect_change=*/true));
+  OpMetadata metadata;
+  metadata.set_op_type("test_type0");
+  metadata.set_op_name("test_name0");
+  auto combined_reduce_scatter = op::Metadata(metadata);
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Tuple(op::GetTupleElement(combined_reduce_scatter, 0),
+                        op::GetTupleElement(combined_reduce_scatter, 1)));
+}
+
 }  // namespace
 }  // namespace xla
