Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XLA] Support nested fusions in HloFusionAdaptor #21258

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 36 additions & 13 deletions xla/hlo/utils/hlo_traversal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
explicit HloComputationFusion(const HloComputation* computation,
const HloFusionAdaptor* parent)
: computation_(computation), parent_(parent) {
// `FindNonTrivialHero` only call `ContainsInstruction` and doesn't use
// `FindNonTrivialHero` only calls `ContainsInstruction` and doesn't use
// information about roots, so we can skip looking for roots as performance
// optimization.
// TODO(shyshkov): Clean this up once priority fusion is fully launched.
Expand Down Expand Up @@ -177,12 +177,21 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
}

bool ContainsInstruction(const HloInstruction* instruction) const override {
return instruction->parent() == computation_ ||
// For convenience, we consider that the adaptor also contains the
// parent fusion instruction. This is useful in
// ResolveUsers/ResolveOperand to check if the given fusion
// instruction is part of the fusion adaptor.
instruction == computation_->FusionInstruction();
// For convenience, we consider that the adaptor also contains the parent
// fusion instruction. This is useful in ResolveUsers/ResolveOperand to
// check if the given fusion instruction is part of the fusion adaptor.
if (instruction == computation_->FusionInstruction()) {
return true;
}
// Check whether the recursive parent computation of the given 'instruction'
// is equal to this fusion computation.
do {
if (instruction->parent() == computation_) {
return true;
}
instruction = instruction->parent()->FusionInstruction();
} while (instruction != nullptr);
return false;
}

absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const override {
Expand Down Expand Up @@ -221,6 +230,14 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
(instr->opcode() == HloOpcode::kTuple && instr->IsRoot())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
// Recurse into nested fusions.
HloComputationFusion nested_fusion(
instr->fused_instructions_computation(), parent_);
absl::c_move(nested_fusion.MakeInstructionPostOrder(),
std::back_inserter(result));
continue;
}
result.emplace_back(*instr, parent_);
}
return result;
Expand All @@ -236,6 +253,13 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
instr->opcode() == HloOpcode::kGetTupleElement) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
// Recurse into nested fusions.
HloComputationFusion nested_fusion(
instr->fused_instructions_computation(), parent_);
nested_fusion.ForEach(fn);
continue;
}
fn(HloInstructionAdaptor{*instr, parent_});
}
}
Expand Down Expand Up @@ -284,10 +308,9 @@ bool HloFusionAdaptor::ContainsInstruction(

bool HloFusionAdaptor::ContainsInstruction(
const HloInstruction* instruction) const {
for (const auto& fusion_instruction : fusion_instructions_) {
if (fusion_instruction->ContainsInstruction(instruction)) return true;
}
return false;
return absl::c_any_of(fusion_instructions_, [&](const auto& adaptor) {
return adaptor->ContainsInstruction(instruction);
});
}

absl::InlinedVector<HloInstructionAdaptor, 2> HloFusionAdaptor::GetRoots()
Expand Down Expand Up @@ -427,8 +450,8 @@ void HloFusionAdaptor::ForEach(

std::string HloFusionAdaptor::ToString() const {
std::ostringstream ss;
for (const auto& fusion_instruction : fusion_instructions_) {
ss << fusion_instruction->ToString() << "\n";
for (const auto& fusion_instruction : MakeInstructionPostOrder()) {
ss << fusion_instruction.ToString() << "\n";
}
return ss.str();
}
Expand Down
12 changes: 10 additions & 2 deletions xla/hlo/utils/hlo_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,29 @@ bool IsOpcodeAnyOf(const HloInstruction* instr) {

namespace internal {

// An interface to abstract away the difference between single instruction
// fusion and fused computations.
// An interface to abstract away the difference between a single instruction
// and a fusion instruction with all its (potentially nested) computations.
class HloFusionInstructionAdaptor {
public:
virtual ~HloFusionInstructionAdaptor() = default;
// Returns true if the given 'instruction' is either the adapted instruction
// or contained in one of its nested computations.
virtual bool ContainsInstruction(const HloInstruction* instruction) const = 0;
// If it is a regular multi-output fusion, the order of the returned roots
// matches the order of the tuple elements of the tuple root of the fusion
// computation. We do not deduplicate fusion roots.
virtual absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const = 0;
// Returns the operands of the adapted instruction.
virtual absl::InlinedVector<const HloInstruction*, 2> GetParameters()
const = 0;
// Returns the adapted instruction.
virtual const HloInstruction& FusionInstruction() const = 0;
// Returns the single instruction or the instructions of the (potentially
// nested) computations, in post order.
virtual absl::InlinedVector<HloInstructionAdaptor, 2>
MakeInstructionPostOrder() const = 0;
// Calls 'fn' for the single instruction or for all instructions in the
// (potentially nested) computations, in some order.
virtual void ForEach(
const std::function<void(HloInstructionAdaptor)>& fn) const = 0;
virtual std::string ToString() const = 0;
Expand Down
31 changes: 31 additions & 0 deletions xla/hlo/utils/hlo_traversal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,37 @@ TEST_F(HloTraversalTest, AdaptorUsers) {
EXPECT_TRUE(neg.GetUsers().empty());
}

// Verifies that HloFusionAdaptor looks through nested fusions: when the root
// 'neg' of the outer fusion consumes the inner fusion, resolving its operands
// must yield the inner fusion's root 'mul' rather than stopping at the nested
// fusion instruction itself.
TEST_F(HloTraversalTest, NestedFusionIsTraversedCorrectly) {
auto module = ParseAndReturnVerifiedModule(
R"(
inner {
p0 = f32[] parameter(0)
ROOT mul = f32[] multiply(p0, p0)
}

outer {
p0 = f32[] parameter(0)
inner = f32[] fusion(p0), kind=kLoop, calls=inner
ROOT neg = f32[] negate(inner)
}

ENTRY entry {
p0 = f32[] parameter(0)
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=outer
}
)")
.value();

// Build the adaptor for the outer fusion instruction (the entry root).
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

// The adaptor's root is the root of the outer fused computation.
HloInstructionAdaptor negate_instruction = fusion_adaptor->GetRoots()[0];

EXPECT_THAT(negate_instruction, InstructionAdaptorName("neg"));
// The nested fusion is transparent: the operand resolves to 'mul' inside it.
EXPECT_THAT(negate_instruction.GetOperands(),
ElementsAre(InstructionAdaptorName("mul")));
}

TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) {
auto module = ParseAndReturnVerifiedModule(kTestModule).value();
std::vector<std::string> visited_nodes;
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/transforms/nest_gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,14 @@ ENTRY entry_computation {
kind=kCustom, calls=dot, backend_config={
"fusion_backend_config":{
"kind":"__triton_gemm","triton_gemm_config":{
"block_m":"16","block_n":"16","block_k":"32",
"block_m":"4","block_n":"16","block_k":"128",
"split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"
}
}
}
}
)";
// Note: block sizes were 16,16,32, but that now fails to satisfy constraints.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true));
TF_ASSERT_OK(verifier().Run(module.get()).status());
Expand Down
Loading