[Codegen] Sprinkle in PropagateDispatchSizeBounds passes
Since the various tiling and distribution passes don't know how to set upper
bounds on workitem or workgroup IDs, even when that information is known from
context, we use the PropagateDispatchSizeBounds pass to add that information
before the passes that use it.

The main passes that use this information are those that use the
ValueBoundsOpInterface - that is, loop-invariant code motion, some
vectorization code, and, in an upcoming commit, RemoveSingleIterationLoop.

These calls can be removed in the future, but they'll do for now.
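For reference, the recurring pattern in this change is sketched below: run the
bounds-propagation pass immediately before the passes that consume the
resulting bounds. This is a minimal sketch rather than code from any one
pipeline here; the helper name addBoundsThenCleanup is hypothetical, and the
create* pass declarations are assumed to be in scope, as they are in the
Codegen Passes.cpp files touched by this commit.

#include "mlir/Pass/PassManager.h"

// Hypothetical helper illustrating the ordering used throughout this commit.
static void addBoundsThenCleanup(mlir::OpPassManager &funcPassManager) {
  // Attach upper bounds to workgroup/workitem ID and count ops based on the
  // dispatch information, so ValueBoundsOpInterface queries can see them.
  funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
  // Passes that consume those bounds, directly or via the interface.
  funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
  funcPassManager.addPass(createRemoveSingleIterationLoopPass());
}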
krzysz00 committed Jan 18, 2025
1 parent 6c59ac1 commit 6d06506
Showing 7 changed files with 40 additions and 15 deletions.
@@ -44,7 +44,8 @@ static std::pair<Value, Value> makeTransposedIds(Location loc, OpBuilder b,
/// Returns the workgroup counts along the X and Y dimensions. These will be
/// constants when static in the corresponding `hal.executable.export` op.
static std::pair<Value, Value>
- getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {
+ getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp,
+ std::optional<APInt> xBound, std::optional<APInt> yBound) {
Location loc = funcOp.getLoc();
SmallVector<int64_t> workgroupCounts = getStaticNumWorkgroups(funcOp);
bool isStaticWgCount = llvm::none_of(workgroupCounts, ShapedType::isDynamic);
@@ -62,9 +63,9 @@ getWorkgroupCountsXY(OpBuilder &builder, FunctionOpInterface funcOp) {

LLVM_DEBUG(llvm::dbgs() << "Using dynamic workgroup counts\n");
Value dynamicCountX =
- builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0);
+ builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 0, xBound);
Value dynamicCountY =
- builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1);
+ builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(loc, 1, yBound);
return {dynamicCountX, dynamicCountY};
}

@@ -100,11 +101,12 @@ reorderWorkgroupsInFunc(FunctionOpInterface funcOp,
// that to RAUW the old ones. This way we don't have to worry about the
// picking the exact insertion points that do not violate dominance between
// their defs and users.
- Value workgroupIdX =
- builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 0);
- Value workgroupIdY =
- builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(funcOp.getLoc(), 1);
- auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(builder, funcOp);
+ Value workgroupIdX = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
+ funcOp.getLoc(), 0, oldXId.getUpperBound());
+ Value workgroupIdY = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
+ funcOp.getLoc(), 1, oldYId.getUpperBound());
+ auto [workgroupCntX, workgroupCntY] = getWorkgroupCountsXY(
+ builder, funcOp, oldXId.getUpperBound(), oldYId.getUpperBound());
Value newWorkgroupIdX;
Value newWorkgroupIdY;
assert(strategy == ReorderWorkgroupsStrategy::Transpose &&
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
@@ -126,6 +126,7 @@ static void addTileAndDistributePasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createFuseTensorPadWithConsumerPass());
funcPassManager.addPass(createConcretizePadResultShapePass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
}

//===---------------------------------------------------------------------===//
@@ -447,6 +448,7 @@ void addMultiTilingExpertPassPipeline(OpPassManager &funcPassManager,
addCPUBufferizePasses(funcPassManager);

// Run IREE specific passes before vector lowering expert.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

{
@@ -510,6 +512,7 @@ void addConvTileAndDecomposeExpertPassPipeline(
addCPUBufferizePasses(funcPassManager);

// Run IREE specific passes before vector lowering expert.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

{
17 changes: 14 additions & 3 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -262,6 +262,7 @@ void addGPUVectorizationPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUDistributePass());

// Post bufferization optimizations.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createCanonicalizerPass());
@@ -439,6 +440,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createTileLargeTensorsPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(IREE::GPU::createCombineBarrierRegionsPass());

@@ -468,6 +470,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

// Step 9. Remaining post-bufferization optimizations/lowerings.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(IREE::GPU::createLowerIREEGPUOpsPass());
funcPassManager.addPass(createUnrollAnnotatedLoopsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
@@ -524,6 +527,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUDistributeScfForPass(options));

// Post bufferization optimizations.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
funcPassManager.addPass(memref::createFoldMemRefAliasOpsPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -544,6 +548,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
// Distribute linalg onto warps within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
if (pipelineDepth > 1) {
funcPassManager.addPass(createGPUMultiBufferingPass(
@@ -589,6 +594,7 @@ void addGPUMatmulTensorCorePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
GPUPipeliningPassOptions pipelieningOptions = {};
@@ -613,6 +619,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
// Distribute linalg onto warps within the workgroup.
funcPassManager.addPass(
createLLVMGPUTileAndDistributePass(/*distributeToWarp=*/true));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
if (pipelineDepth > 1) {
funcPassManager.addPass(createGPUMultiBufferingPass(
@@ -655,6 +662,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(
funcPassManager.addPass(createCSEPass());

// Hoist loop invariant code to avoid pipelining it.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createIREELoopInvariantCodeMotionPass());
// Pipeline memory operations.
GPUPipeliningPassOptions pipelieningOptions = {};
@@ -882,6 +890,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUTileReductionPass());
funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

// Linalg -> vector
{
@@ -949,6 +958,7 @@ void addGPUSimpleDistributePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());

funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
}

@@ -965,6 +975,7 @@ void addGPUDefaultPassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createCSEPass());

addBufferizePasses(funcPassManager);
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
}

@@ -981,6 +992,7 @@ void addGPUBaseLoweringPassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createConvertLinalgToLoopsPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
@@ -999,6 +1011,7 @@ addLowerAndOptimizeAddressComputationPasses(FunctionLikeNest &funcPassManager) {
.addPass(memref::createExpandOpsPass)
.addPass(memref::createFoldMemRefAliasOpsPass)
.addPass(memref::createExpandStridedMetadataPass)
.addPass(createPropagateDispatchSizeBoundsPass)
// Hoist loop invariant variables to give affine decomposition pass the
// right loop dependencies.
.addPass(createIREELoopInvariantCodeMotionPass)
@@ -1055,9 +1068,7 @@ static void addLowerToLLVMGPUPasses(OpPassManager &modulePassManager,
FunctionLikeNest funcPassManager(modulePassManager);
funcPassManager.addPass(createFoldTensorExtractOpPass)
.addPass(createLLVMGPUVectorLoweringPass)
- .addPass(createExpandGPUOpsPass)
- // Expose workitem and workgroup counts to range inference later.
- .addPass(createPropagateDispatchSizeBoundsPass);
+ .addPass(createExpandGPUOpsPass);

// This pass needs to run before SCF -> CF.
addLowerAndOptimizeAddressComputationPasses(funcPassManager);
7 changes: 7 additions & 0 deletions compiler/src/iree/compiler/Codegen/SPIRV/Passes.cpp
@@ -163,6 +163,7 @@ static void addLoopMaterializationPasses(OpPassManager &funcPassManager) {
funcPassManager.addPass(IREE::LinalgExt::createLinalgExtToLoopsPass());
funcPassManager.addPass(createMemrefCopyToLinalgPass());
funcPassManager.addPass(createConvertLinalgToLoopsPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
@@ -394,6 +395,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
funcPassManager.addPass(
createSPIRVTileAndPromotePass(SPIRVTileAndPromotePassOptions{
/*promoteCMatrix=*/true, /*skipThreadLevel=*/true}));
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());
// Run canonicalization patterns to propagate constant shape sizes after
// removing trip-one loops.
@@ -421,6 +423,7 @@ void addSPIRVCooperativeMatrixVectorizePassPipeline(
funcPassManager.addPass(createGPUReduceBankConflictsPass(options));
}

funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
// Performs high-level n-D mechanical vectorization. This does not perform
// unrolling or lowering, which is done later.
{
@@ -513,6 +516,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createGPUDistributeSharedMemoryCopyPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

{
GPUReduceBankConflictsPassOptions options = {};
@@ -532,6 +536,7 @@ void addSPIRVMatmulPromoteVectorizePassPipeline(OpPassManager &funcPassManager,
funcPassManager.addPass(createForOpCanonicalizationPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());

// Hoist loop invariant code to avoid pipelining it.
@@ -560,6 +565,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {
funcPassManager.addPass(createGPUTileReductionPass());
funcPassManager.addPass(createCanonicalizerPass());
funcPassManager.addPass(createCSEPass());
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());

// Performs high-level n-D mechanical vectorization. This does not perform
// unrolling or lowering, which is done later.
@@ -588,6 +594,7 @@ void addSPIRVSubgroupReducePassPipeline(OpPassManager &funcPassManager) {

// Perform various vector-level cross-op optimizations like load-store
// forwarding, shape casting and casting op cancelling.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createOptimizeVectorTransferPass());

// Simplify the IR for vector distribution.
@@ -142,8 +142,8 @@ func.func @warp_reduction_dispatch() attributes {hal.executable.target = #execut

// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f16

- // CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] : index
- // CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] : index
+ // CHECK-DAG: %[[WGIDX:.+]] = hal.interface.workgroup.id[0] upper_bound 65535 : index
+ // CHECK-DAG: %[[WGIDY:.+]] = hal.interface.workgroup.id[1] upper_bound 65535 : index
// CHECK-DAG: %[[TIDX:.+]] = gpu.thread_id x

// CHECK-DAG: %[[SPAN0:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(0)
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/VMVX/Passes.cpp
@@ -71,6 +71,7 @@ void addVMVXDefaultPassPipeline(OpPassManager &funcPassManager,
addCPUBufferizePasses(funcPassManager);

// Cleanup the IR that may now have unused loops.
funcPassManager.addPass(createPropagateDispatchSizeBoundsPass());
funcPassManager.addPass(createRemoveSingleIterationLoopPass());

// Convert buffer-level microkernels.
5 changes: 3 additions & 2 deletions compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td
@@ -3039,9 +3039,10 @@ class HAL_InterfaceWorkgroupOp<string mnemonic, list<Trait> traits = []>
let results = (outs HAL_Dim:$result);

let builders = [
OpBuilder<(ins "unsigned":$dim),
OpBuilder<(ins "unsigned":$dim, CArg<"std::optional<::llvm::APInt>", "std::nullopt">:$upper_bound),
[{
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim), ::mlir::IntegerAttr{});
build($_builder, $_state, $_builder.getIndexType(), $_builder.getIndexAttr(dim),
upper_bound.has_value() ? $_builder.getIndexAttr(upper_bound->getSExtValue()) : ::mlir::IntegerAttr{});
}]>,
];
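Usage note for the widened builder above: passing an APInt attaches the
upper_bound attribute (for example, the 65535 seen in the updated test), while
omitting it keeps the previous unannotated form. A hedged sketch follows;
builder, loc, and countBound (a std::optional<llvm::APInt>) are illustrative
names, not taken from this change.

// Omitting the bound keeps the old behavior: no upper_bound attribute is set.
Value idNoBound =
    builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(loc, /*dim=*/0);
// Passing an APInt sets the upper_bound attribute on the resulting op.
Value idWithBound = builder.create<IREE::HAL::InterfaceWorkgroupIDOp>(
    loc, /*dim=*/1, llvm::APInt(/*numBits=*/64, /*val=*/65535));
// The same optional argument is available on the workgroup count builder.
Value countX = builder.create<IREE::HAL::InterfaceWorkgroupCountOp>(
    loc, /*dim=*/0, /*upper_bound=*/countBound);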
