Skip to content

Commit

Permalink
[XLA:SchedulingAnnotations] Handle instructions with control dependencies.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 713805758
  • Loading branch information
seherellis authored and Google-ML-Automation committed Jan 9, 2025
1 parent 5a6ef8a commit 5ceca0e
Showing 4 changed files with 84 additions and 33 deletions.
5 changes: 3 additions & 2 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
@@ -1191,6 +1191,7 @@ cc_library(
"//xla/hlo/analysis:hlo_alias_analysis",
"//xla/hlo/analysis:hlo_reachability",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:ptrvec",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -1203,8 +1204,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:status",
"@tsl//tsl/platform:statusor",
],
)
@@ -6425,6 +6424,7 @@ cc_library(
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:ptrvec",
"//xla/hlo/pass:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -6448,6 +6448,7 @@ xla_cc_test(
"//xla:util",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:test_helpers",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
12 changes: 8 additions & 4 deletions xla/service/latency_hiding_scheduler.h
Original file line number Diff line number Diff line change
@@ -43,6 +43,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/ir/ptrvec.h"
#include "xla/hlo/pass/hlo_pass_interface.h"
#include "xla/map_util.h"
#include "xla/service/hlo_buffer.h"
@@ -377,10 +378,13 @@ class AnnotationTracker {
annotations_[annotation].begin(), annotations_[annotation].end());
for (const HloInstruction* instr : annotations_.at(annotation)) {
bool has_annotated_user = false;
for (HloInstruction* user : instr->users()) {
if (seen_instructions.contains(user)) {
has_annotated_user = true;
break;
for (const PtrVec<HloInstruction*>& users :
{instr->users(), instr->control_successors()}) {
for (HloInstruction* user : users) {
if (seen_instructions.contains(user)) {
has_annotated_user = true;
break;
}
}
}
if (!has_annotated_user) {
61 changes: 34 additions & 27 deletions xla/service/legalize_scheduling_annotations.cc
Original file line number Diff line number Diff line change
@@ -32,6 +32,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/ptrvec.h"
#include "xla/side_effect_util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/statusor.h"
@@ -183,14 +184,17 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
"Done instruction's operand is not annotated with the same id: ",
instr->operand(0)->name(), ", annotation: ", id));
}
for (HloInstruction* user : instr->users()) {
if (!visited.contains(user) &&
(!annotation.contains(user) || annotation[user] != id)) {
stack.push_back(user);
parent[user] = instr;
visited.insert(user);
VLOG(2) << "Annotation group: " << id
<< ", frontier using a root: " << user->name();
for (const PtrVec<HloInstruction*>& users :
{instr->users(), instr->control_successors()}) {
for (HloInstruction* user : users) {
if (!visited.contains(user) &&
(!annotation.contains(user) || annotation[user] != id)) {
stack.push_back(user);
parent[user] = instr;
visited.insert(user);
VLOG(2) << "Annotation group: " << id
<< ", frontier using a root: " << user->name();
}
}
}
}
@@ -202,28 +206,31 @@ absl::StatusOr<bool> LegalizeSchedulingAnnotations::Run(
while (!stack.empty()) {
HloInstruction* instr = stack.back();
stack.pop_back();
for (HloInstruction* user : instr->users()) {
if (annotation.contains(user) && annotation[user] == id) {
LOG(INFO) << "PATH: " << user->name();
HloInstruction* current = instr;
LOG(INFO) << "PATH: " << current->name();
while (parent.contains(current)) {
current = parent[current];
for (const PtrVec<HloInstruction*>& users :
{instr->users(), instr->control_successors()}) {
for (HloInstruction* user : users) {
if (annotation.contains(user) && annotation[user] == id) {
LOG(INFO) << "PATH: " << user->name();
HloInstruction* current = instr;
LOG(INFO) << "PATH: " << current->name();
while (parent.contains(current)) {
current = parent[current];
LOG(INFO) << "PATH: " << current->name();
}
return absl::UnimplementedError(absl::StrCat(
"Support for annotation groups with gaps doesn't "
"exist yet, annotation: ",
id, ", instr: ", user->name(),
" has the same annotation in its operand tree but "
"has gaps on the way from that operand to itself."));
}
return absl::UnimplementedError(
absl::StrCat("Support for annotation groups with gaps doesn't "
"exist yet, annotation: ",
id, ", instr: ", user->name(),
" has the same annotation in its operand tree but "
"has gaps on the way from that operand to itself."));
}
if (visited.contains(user)) {
continue;
if (visited.contains(user)) {
continue;
}
stack.push_back(user);
parent[user] = instr;
visited.insert(user);
}
stack.push_back(user);
parent[user] = instr;
visited.insert(user);
}
}
}
39 changes: 39 additions & 0 deletions xla/service/legalize_scheduling_annotations_test.cc
Original file line number Diff line number Diff line change
@@ -23,6 +23,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/hlo/ir/hlo_schedule.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/hlo/testlib/test_helpers.h"
#include "xla/side_effect_util.h"
#include "xla/test_helpers.h"
#include "xla/tsl/platform/statusor.h"
@@ -281,5 +282,43 @@ TEST_F(LegalizeSchedulingAnnotationsTest, DropAnnotationFromBitcast) {
bitcast->frontend_attributes().map().contains(kXlaSchedulingGroupIdAttr));
}

// Checks that legalization returns an OK status for a module in which the
// instructions annotated with scheduling group "0" include a recv-done
// (recv-done.1) that has a control predecessor (send-done.1) from the same
// group, next to unannotated send/recv chains that also carry control
// dependencies.
TEST_F(LegalizeSchedulingAnnotationsTest, OpsWithControlDependencies) {
  constexpr absl::string_view kModuleText = R"(
HloModule module, is_scheduled=true
ENTRY entry {
p0 = f32[16,64,256]{2,1,0} parameter(0)
p2 = f32[512,2048,2048]{2,1,0} parameter(2)
after-all = token[] after-all()
send = (f32[512,2048,2048]{2,1,0}, u32[], token[]) send(p2, after-all), channel_id=1
send-done = token[] send-done(send), channel_id=1
recv = (f32[512,2048,2048]{2,1,0}, u32[], token[]) recv(after-all), channel_id=2
recv-done = (f32[512,2048,2048]{2,1,0}, token[]) recv-done(recv), channel_id=2, control-predecessors={send-done}
get-tuple-element = f32[512,2048,2048]{2,1,0} get-tuple-element(recv-done), index=0
slice = f32[16,64,256]{2,1,0} slice(get-tuple-element), slice={[0:16], [0:64], [0:256]}
c0 = f32[16,256,256]{2,1,0} convolution(p0, slice), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb
c1 = f32[16,256,256]{2,1,0} convolution(p0, slice), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="0"}
p1 = f32[128,2048,2048]{2,1,0} parameter(1)
after-all.1 = token[] after-all()
send.1 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) send(p1, after-all.1), channel_id=3, frontend_attributes={_scheduling_group_id="0"}
send-done.1 = token[] send-done(send.1), channel_id=3, frontend_attributes={_scheduling_group_id="0"}
recv.1 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) recv(after-all.1), channel_id=4, frontend_attributes={_scheduling_group_id="0"}
recv-done.1 = (f32[128,2048,2048]{2,1,0}, token[]) recv-done(recv.1), channel_id=4, frontend_attributes={_scheduling_group_id="0"}, control-predecessors={send-done.1}
get-tuple-element.1 = f32[128,2048,2048]{2,1,0} get-tuple-element(recv-done.1), index=0
after-all.2 = token[] after-all()
send.2 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) send(get-tuple-element.1, after-all.2), channel_id=5
send-done.2 = token[] send-done(send.2), channel_id=5
recv.2 = (f32[128,2048,2048]{2,1,0}, u32[], token[]) recv(after-all.2), channel_id=6
recv-done.2 = (f32[128,2048,2048]{2,1,0}, token[]) recv-done(recv.2), channel_id=6, control-predecessors={send-done.2}
get-tuple-element.2 = f32[128,2048,2048]{2,1,0} get-tuple-element(recv-done.2), index=0
ROOT tuple.2 = (f32[16,256,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[128,2048,2048]{2,1,0}, f32[128,2048,2048]{2,1,0}) tuple(c0, c1, get-tuple-element.1, get-tuple-element.2)
}
)";
  // Parse (and verify) the module; a parse failure aborts the test here.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kModuleText));

  // Run the pass with a default-constructed configuration.
  LegalizeSchedulingAnnotations::Config pass_config;
  LegalizeSchedulingAnnotations legalizer(pass_config);

  // Only the returned status is inspected; the "changed" flag is ignored.
  EXPECT_IS_OK(legalizer.Run(module.get()).status());
}
} // namespace
} // namespace xla

0 comments on commit 5ceca0e

Please sign in to comment.