Skip to content

Commit

Permalink
[XLA] Support nested fusions in HloFusionAdaptor
Browse files Browse the repository at this point in the history
So far, we assumed that `HloComputationFusion` itself contains no fusion instructions. Adding support for that is one step towards a generic Triton emitter that uses nested fusions for the operands of some specific ops (`dot`, `reduce` and potentially `concat`).

PiperOrigin-RevId: 714900385
  • Loading branch information
chsigg authored and Google-ML-Automation committed Jan 13, 2025
1 parent 0d531ac commit cdd12c2
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 16 deletions.
49 changes: 36 additions & 13 deletions xla/hlo/utils/hlo_traversal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
explicit HloComputationFusion(const HloComputation* computation,
const HloFusionAdaptor* parent)
: computation_(computation), parent_(parent) {
// `FindNonTrivialHero` only call `ContainsInstruction` and doesn't use
// `FindNonTrivialHero` only calls `ContainsInstruction` and doesn't use
// information about roots, so we can skip looking for roots as performance
// optimization.
// TODO(shyshkov): Clean this up once priority fusion is fully launched.
Expand Down Expand Up @@ -177,12 +177,21 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
}

bool ContainsInstruction(const HloInstruction* instruction) const override {
return instruction->parent() == computation_ ||
// For convenience, we consider that the adaptor also contains the
// parent fusion instruction. This is useful in
// ResolveUsers/ResolveOperand to check if the given fusion
// instruction is part of the fusion adaptor.
instruction == computation_->FusionInstruction();
// For convenience, we consider that the adaptor also contains the parent
// fusion instruction. This is useful in ResolveUsers/ResolveOperand to
// check if the given fusion instruction is part of the fusion adaptor.
if (instruction == computation_->FusionInstruction()) {
return true;
}
// Walk up the chain of enclosing fusion computations to check whether
// 'instruction' is (transitively) nested inside this fusion computation.
do {
if (instruction->parent() == computation_) {
return true;
}
instruction = instruction->parent()->FusionInstruction();
} while (instruction != nullptr);
return false;
}

absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const override {
Expand Down Expand Up @@ -221,6 +230,14 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
(instr->opcode() == HloOpcode::kTuple && instr->IsRoot())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
// Recurse into nested fusions.
HloComputationFusion nested_fusion(
instr->fused_instructions_computation(), parent_);
absl::c_move(nested_fusion.MakeInstructionPostOrder(),
std::back_inserter(result));
continue;
}
result.emplace_back(*instr, parent_);
}
return result;
Expand All @@ -236,6 +253,13 @@ class HloComputationFusion : public internal::HloFusionInstructionAdaptor {
instr->opcode() == HloOpcode::kGetTupleElement) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
// Recurse into nested fusions.
HloComputationFusion nested_fusion(
instr->fused_instructions_computation(), parent_);
nested_fusion.ForEach(fn);
continue;
}
fn(HloInstructionAdaptor{*instr, parent_});
}
}
Expand Down Expand Up @@ -284,10 +308,9 @@ bool HloFusionAdaptor::ContainsInstruction(

bool HloFusionAdaptor::ContainsInstruction(
const HloInstruction* instruction) const {
for (const auto& fusion_instruction : fusion_instructions_) {
if (fusion_instruction->ContainsInstruction(instruction)) return true;
}
return false;
return absl::c_any_of(fusion_instructions_, [&](const auto& adaptor) {
return adaptor->ContainsInstruction(instruction);
});
}

absl::InlinedVector<HloInstructionAdaptor, 2> HloFusionAdaptor::GetRoots()
Expand Down Expand Up @@ -427,8 +450,8 @@ void HloFusionAdaptor::ForEach(

std::string HloFusionAdaptor::ToString() const {
std::ostringstream ss;
for (const auto& fusion_instruction : fusion_instructions_) {
ss << fusion_instruction->ToString() << "\n";
for (const auto& fusion_instruction : MakeInstructionPostOrder()) {
ss << fusion_instruction.ToString() << "\n";
}
return ss.str();
}
Expand Down
12 changes: 10 additions & 2 deletions xla/hlo/utils/hlo_traversal.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,29 @@ bool IsOpcodeAnyOf(const HloInstruction* instr) {

namespace internal {

// An interface to abstract away the difference between single instruction
// fusion and fused computations.
// An interface to abstract away the difference between a single instruction
// and a fusion instruction with all its (potentially nested) computations.
class HloFusionInstructionAdaptor {
public:
virtual ~HloFusionInstructionAdaptor() = default;
// Returns true if the given 'instruction' is either the adapted instruction
// or contained in one of its nested computations.
virtual bool ContainsInstruction(const HloInstruction* instruction) const = 0;
// If it is a regular multi-output fusion, the order of the returned roots
// matches the order of the tuple elements of the tuple root of the fusion
// computation. We do not deduplicate fusion roots.
virtual absl::InlinedVector<HloInstructionAdaptor, 2> GetRoots() const = 0;
// Returns the operands of the adapted instruction.
virtual absl::InlinedVector<const HloInstruction*, 2> GetParameters()
const = 0;
// Returns the adapted instruction.
virtual const HloInstruction& FusionInstruction() const = 0;
// Returns the single instruction or the instructions of the (potentially
// nested) computations, in post order.
virtual absl::InlinedVector<HloInstructionAdaptor, 2>
MakeInstructionPostOrder() const = 0;
// Calls 'fn' for the single instruction or for each instruction in the
// (potentially nested) computations, in some order.
virtual void ForEach(
const std::function<void(HloInstructionAdaptor)>& fn) const = 0;
virtual std::string ToString() const = 0;
Expand Down
31 changes: 31 additions & 0 deletions xla/hlo/utils/hlo_traversal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,37 @@ TEST_F(HloTraversalTest, AdaptorUsers) {
EXPECT_TRUE(neg.GetUsers().empty());
}

// Verifies that HloFusionAdaptor looks through nested fusion instructions:
// resolving the operands of the outer fusion's root ('neg') yields the inner
// fusion computation's root ('mul'), not the nested 'inner' fusion
// instruction itself.
TEST_F(HloTraversalTest, NestedFusionIsTraversedCorrectly) {
// The 'outer' fusion computation contains a nested fusion ('inner', a
// single multiply) whose result feeds the root 'neg'.
auto module = ParseAndReturnVerifiedModule(
R"(
inner {
p0 = f32[] parameter(0)
ROOT mul = f32[] multiply(p0, p0)
}
outer {
p0 = f32[] parameter(0)
inner = f32[] fusion(p0), kind=kLoop, calls=inner
ROOT neg = f32[] negate(inner)
}
ENTRY entry {
p0 = f32[] parameter(0)
ROOT fusion = f32[] fusion(p0), kind=kLoop, calls=outer
}
)")
.value();

// Build the adaptor for the outer fusion instruction (the entry root).
auto fusion_adaptor = HloFusionAdaptor::ForInstruction(
module->entry_computation()->root_instruction());

// The adaptor's root is the outer computation's root instruction 'neg'.
HloInstructionAdaptor negate_instruction = fusion_adaptor->GetRoots()[0];

EXPECT_THAT(negate_instruction, InstructionAdaptorName("neg"));
// The operand resolves through the nested fusion instruction to 'mul'.
EXPECT_THAT(negate_instruction.GetOperands(),
ElementsAre(InstructionAdaptorName("mul")));
}

TEST_F(HloTraversalTest, TraverseFusionConsumerFirst) {
auto module = ParseAndReturnVerifiedModule(kTestModule).value();
std::vector<std::string> visited_nodes;
Expand Down
3 changes: 2 additions & 1 deletion xla/service/gpu/transforms/nest_gemm_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,14 @@ ENTRY entry_computation {
kind=kCustom, calls=dot, backend_config={
"fusion_backend_config":{
"kind":"__triton_gemm","triton_gemm_config":{
"block_m":"16","block_n":"16","block_k":"32",
"block_m":"4","block_n":"16","block_k":"128",
"split_k":"1","num_stages":"1","num_warps":"4","num_ctas":"1"
}
}
}
}
)";
// Note: block sizes were 16,16,32, but that now fails to satisfy constraints.
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(hlo));
EXPECT_THAT(NestGemmFusion().Run(module.get()), IsOkAndHolds(true));
TF_ASSERT_OK(verifier().Run(module.get()).status());
Expand Down

0 comments on commit cdd12c2

Please sign in to comment.