
[AMD] Add instruction schedule loop boundary guard hints #5163

Draft
wants to merge 5 commits into main from ravil/fa-sched
Conversation

ravil-mobile
Contributor

@ravil-mobile ravil-mobile commented Nov 15, 2024

Extended AMDGPU instruction scheduling.

  • The introduced source code changes add sched.barriers at the beginning and at the end of each scf.For op body (called guards). The guards prevent instructions from moving between a for-loop body and its adjacent basic blocks. According to test results, this increases performance for FA-like kernels due to a reduction in VGPR spilling (see the sketch after the checklist below).

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)
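For illustration, a minimal C++ sketch of what the guard insertion amounts to, assuming MLIR's ROCDL sched.barrier op and standard builder APIs (a sketch, not this PR's actual code):

#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/SCF/IR/SCF.h"

using namespace mlir;

// Place a full sched.barrier (mask 0) at the start and the end of the loop
// body so the backend scheduler cannot move instructions across the loop
// boundary in either direction.
static void insertLoopBoundaryGuards(scf::ForOp forOp, OpBuilder &builder) {
  Block *body = forOp.getBody();
  Location loc = forOp.getLoc();
  builder.setInsertionPointToStart(body);
  builder.create<ROCDL::SchedBarrier>(loc, builder.getI32IntegerAttr(0));
  builder.setInsertionPoint(body->getTerminator());
  builder.create<ROCDL::SchedBarrier>(loc, builder.getI32IntegerAttr(0));
}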

@ravil-mobile ravil-mobile force-pushed the ravil/fa-sched branch 3 times, most recently from 2f4d48f to 681a1f7 on November 18, 2024 at 11:28
@ravil-mobile ravil-mobile marked this pull request as ready for review November 18, 2024 11:47
@ravil-mobile
Contributor Author

@sjw, @antiagainst, @zhanglx13, could you please review the code?

@zhanglx13 zhanglx13 marked this pull request as draft November 18, 2024 20:46
@@ -64,7 +64,7 @@ class HIPOptions:
     # Kernel library. Note, this variant requires the use of buffer load/store ops
     # and a special software pipelining style - i.e., 1x LDS and 1x register
     # prefetch buffers for each GEMM tile.
-    instruction_sched_variant: str = 'none'
+    instruction_sched_variant: str = 'guard'
Collaborator

Can you update the comment above to:

  • Move the "none" variant from L51 to L57
  • Add a new bullet for "guard". Also, "guard" is too generic; what about naming it "loop_boundary_guard" or something?

Contributor Author

Move the "none" variant from L51 to L57

Done

Add a new bullet for "guard". Also "guard" is too generic; what about naming it as "loop_boundary_guard" or something.

Moved to a dedicated op and pass

third_party/amd/include/TritonAMDGPUToLLVM/Passes.td (outdated, resolved)
-  case SchedulingType::LLVM_IGLP_0:
-  case SchedulingType::LLVM_IGLP_1:
+  case triton::amdgpu::SchedHint::llvm_iglp_0:
+    [[fallthrough]];
Collaborator

I'm not sure what value these [[fallthrough]] markers provide. It's pretty clear to C/C++ folks that it should fall through. Having them just takes up multiple lines of code and adds maintenance burden. Can you drop them, please?

Contributor Author

The [[fallthrough]] attribute explicitly indicates intended fallthrough, which improves code clarity and suppresses compiler warnings. Are you sure you want to drop it?
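For context, a small generic C++ example (not from this PR) of what the attribute documents; with -Wimplicit-fallthrough enabled, the annotated case compiles warning-free:

// Generic illustration: [[fallthrough]] records intent and suppresses
// -Wimplicit-fallthrough; it does not change the generated code.
int countGroups(int variant) {
  int groups = 0;
  switch (variant) {
  case 0:
    ++groups;        // variant 0 adds an extra group and then
    [[fallthrough]]; // deliberately continues into case 1
  case 1:
    ++groups;
    break;
  default:
    break;
  }
  return groups;
}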

Contributor Author

Done

triton::amdgpu::SchedHint schedulingType =
    instructionSchedHint.getSchedVariant();
if ((this->numStages < 2) &&
    (schedulingType != triton::amdgpu::SchedHint::guard)) {
Collaborator

Is this the case for all other variants? (At least I see we don't need to emit the debug message below for the none case.)

Contributor Author

SchedHint::guard has been removed from this pass.

mod.walk([this, ctx](scf::ForOp forOp) {
  auto maybeSchedHint = triton::amdgpu::symbolizeSchedHint(this->variant);
  if (!maybeSchedHint) {
    LDBG("Skipping instruction scheduling: unknown scheduling hint.");
Collaborator

Can you also print out the provided hint variant? For error messages, be explicit and verbose--it could save you some debugging time down the road. :)
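For illustration, the more verbose message might look something like this (hypothetical wording, using the LDBG macro already used in the pass):

LDBG("Skipping instruction scheduling: unknown scheduling hint '"
     << this->variant << "'.");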

Contributor Author

Done

Comment on lines 361 to 376

namespace {
void getAllInnerForOps(scf::ForOp forOp,
                       llvm::SetVector<scf::ForOp> &innermostForOps) {
  bool found = false;
  forOp.getBody()->walk([&found, &innermostForOps](scf::ForOp innerForOp) {
    getAllInnerForOps(innerForOp, innermostForOps);
    found = true;
  });
  if (!found)
    innermostForOps.insert(forOp);
}
} // namespace

namespace mlir::triton::AMD {
llvm::SetVector<scf::ForOp> getAllInnerForOps(mlir::triton::FuncOp funcOp) {
  llvm::SetVector<scf::ForOp> innermostForOps{};
  funcOp->walk(
      [&](scf::ForOp forOp) { ::getAllInnerForOps(forOp, innermostForOps); });
  return innermostForOps;
}
} // namespace mlir::triton::AMD
Collaborator

This implementation is more complicated than I'd expect. I think something like the following would do?

SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp) {
  SmallVector<scf::ForOp> allOps;
  funcOp->walk([&](scf::ForOp forOp) { allOps.push_back(forOp); });

  SmallVector<scf::ForOp> leafOps;
  for (scf::ForOp forOp : allOps) {
    // Walk only the body: Operation::walk visits the op itself, so walking
    // forOp directly would always interrupt and report no leaves.
    auto r = forOp.getBody()->walk(
        [](scf::ForOp) { return WalkResult::interrupt(); });
    if (!r.wasInterrupted())
      leafOps.push_back(forOp);
  }
  return leafOps;
}

Contributor Author

@antiagainst, I totally agree with your implementation. It is more concise and does the same thing. Thanks!
Done

return;
}

mod.walk([this](triton::FuncOp funcOp) {
Collaborator

I'm confused why we run this together with the logic at L537. They are effectively disjoint. Can you organize this together with L537 better?

Contributor Author

@antiagainst @zhanglx13

The instruction scheduling hint was designed for a single tt.DotOp inside an scf.ForOp. It is complicated right now to extend it to support multiple tt.DotOps in a single region. There may be cases where a single tt.Load feeds multiple tt.DotOps (the same may hold for ttg.LocalLoad and ttg.LocalStore). In these cases, it is not 100% clear to which hint the corresponding ds_reads/ds_writes/global_loads/buffer_loads need to be attached for interleaving. Probably, ds_reads/global_loads/buffer_loads need to be attached to the first hint (which refers to the first tt.DotOp in the region) and ds_writes to the last one. However, there may be other computations in between any two tt.DotOps which may also involve load/store operations - e.g., tt.Load. As @sjw36 mentioned in our chat, we probably need to move to a proper DAG approach.

I'd like to point out that I thought we urgently needed the guard option for FA-like kernels; that is the main goal of this PR.

this->variant = std::move(variant.str());
}

void guardFlashAttentionLikeProblems(triton::FuncOp funcOp) {
Collaborator

This disregards the developer's choice and forcefully sets it for attention. Why is that?

Contributor Author

guardFlashAttentionLikeProblems has been removed in the new implementation

LDBG("skipping `local-prefetch` scheduling given it needs `buffer_load` "
"instructions");
LDBG("skipping `local_prefetch` scheduling given it needs `buffer_load` "
"instructions.");
Collaborator

I've always wanted to ask: why does the local_prefetch scheduling need a specific type of instruction? In this particular case, why does global_load not work?

Contributor Author

global_load instructions involve address calculations (VALU instructions) with loop-carried dependencies. It turns out this messes up the scheduling imposed by the LLVM intrinsic calls - i.e., the compiler backend gets confused. In short, the CK-like scheduling works only with buffer_load instructions. Based on my experience, it is not worth applying local_prefetch to a GEMM variant that involves global_loads.

Collaborator

So buffer_load does address updates with SALU instructions?

the compiler backend gets confused

How confused? Does it place those VALUs in the wrong place so that their latencies are exposed?

@zhanglx13
Collaborator

The PR title is misleading. We don't need anything special for flash-attention-like kernels.
All we need is to add a new sched_variant, in this case "guard" or some better name, so that

  1. It can be set as kernel arg
  2. If set, we insert sched.barrier at loop boundaries if the loop contains at least one dotOp.

@antiagainst antiagainst changed the title from "[AMD] Added instr.sched guards for the FA-like kernels" to "[AMD] Add instruction schedule loop boundary guard hints" on Nov 19, 2024
@ravil-mobile
Contributor Author

The PR title is misleading. We don't need anything special for flash-attention-like kernels. All we need is to add a new sched_variant, in this case "guard" or some better name, so that

  1. It can be set as kernel arg
  2. If set, we insert sched.barrier at loop boundaries if the loop contains at least one dotOp.

@zhanglx13

I'd like to propose something different - i.e., a new instruction in our dialect dedicated to instruction scheduling guards (triton::amdgpu::InstructionSchedGuard). We would need to introduce a dedicated conversion pass as well. This allows us to place guards independently of the number of tt.DotOps in a region. Moreover, the code will become more readable because of the separation of concerns between the scheduling and guarding logic.

  1. Some instruction scheduling variants (e.g., local_prefetch) can add InstructionSchedGuard to a region which is going to be lowered to the corresponding LLVM intrinsic calls later
  2. A user can set a kernel argument to guard all scf.ForOps
  3. A user can set guards to a specific for-loop via a dedicated tt.range parameter

    # Kernel library. Note, this variant requires the use of buffer load/store ops
    # and a special software pipelining style - i.e., 1x LDS and 1x register
    # prefetch buffers for each GEMM tile.
    instruction_sched_variant: str = 'none'

    # The following option prevents moves of instructions from the regions where they are defined.
Collaborator

Hmm, I'm not sure we want to introduce yet another separate knob here. The combinational effect is not something I'm convinced we need right now. Why can't this be just another option like the above? Even though this is experimental right now, I'd prefer to trim down the number of different variants to leave only the proven useful ones, not a combination of lots of different configurations growing unbounded.

If you need to turn on the boundary guard for some existing variants, it can be achieved with patterns. For example, you can have two patterns, one LowerBoundaryGuardToSchedBarrier, and the other RemoveBoundaryGuard. If the variant is local_prefetch/etc., then you pull in the first pattern; otherwise pull in the second. This keeps the lowering patterns as separate concerns while making them composable.
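A rough C++ sketch of those two patterns (the guard op name follows the proposal in this thread and is hypothetical at this point):

// Mechanical 1:1 rewrites, no extra checks inside the patterns themselves.
struct LowerBoundaryGuardToSchedBarrier
    : public OpRewritePattern<triton::amdgpu::InstructionSchedGuard> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedGuard op,
                                PatternRewriter &rewriter) const override {
    // Mask 0 forbids the backend from moving anything across the barrier.
    rewriter.replaceOpWithNewOp<ROCDL::SchedBarrier>(
        op, rewriter.getI32IntegerAttr(0));
    return success();
  }
};

struct RemoveBoundaryGuard
    : public OpRewritePattern<triton::amdgpu::InstructionSchedGuard> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedGuard op,
                                PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return success();
  }
};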

def TritonAMDGPUInsertInstructionSchedHints : Pass<"triton-amdgpu-insert-instruction-sched-hints", "mlir::ModuleOp"> {
  let summary = "Insert instruction scheduling hints after the dot ops in the main loop";
  let constructor = "mlir::triton::createTritonAMDGPUInsertInstructionSchedHintsPass()";
def TritonAMDGPUInsertInstructionControlLogic : Pass<"triton-amdgpu-insert-instruction-control-logic", "mlir::ModuleOp"> {
Collaborator

Following the previous comment about avoiding combinational knobs, here I'm not sure we want a proliferation of passes that each handle just one op. They should be folded into existing passes given their functionality is closely related. You can use the greedy pattern rewriter for simple 1:1 op rewrites.
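For illustration, folding the rewrite into an existing pass via the greedy driver might look like this (pattern names reused from the sketch above; the variant check is illustrative):

// Inside the existing pass's runOnOperation():
RewritePatternSet patterns(&getContext());
if (this->variant == "local_prefetch")
  patterns.add<LowerBoundaryGuardToSchedBarrier>(&getContext());
else
  patterns.add<RemoveBoundaryGuard>(&getContext());
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
  signalPassFailure();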

@@ -491,14 +469,15 @@ struct TritonAMDGPULowerInstructionSchedHints
     ConversionTarget target(*ctx);
     target.addLegalDialect<LLVM::LLVMDialect>();
     target.addIllegalOp<triton::amdgpu::InstructionSchedHint>();
+    target.addLegalOp<triton::amdgpu::InstructionSchedGuard>();
Collaborator

I think this whole pass should be restructured. What about something like this: define a few patterns for lowering the sched hint op - LowerSchedHintIntoIGLP, LowerSchedHintIntoSchedGroupBarrier, RemoveSchedHint, etc. And also, as I commented before, LowerBoundaryGuardIntoSchedBarrier and RemoveBoundaryGuard. Each pattern does no excessive checks; it just does the mechanical op conversion its name indicates. In runOnOperation, check the chosen sched variant and pull in patterns accordingly, then run the greedy pattern rewriter. This gives us a clear structure: the switch/conditions are in the pass, with each pattern dedicated to its own task. Easy to grow/shrink, avoiding massive/branchy patterns.
