[LLVMGPU] Add pass to distribute undistributed copies to threads (#19715)

This pass walks a function and distributes any memref copy that is not
already nested inside an scf.forall mapped to threads/warps/lanes. The pass
assumes that implicit distribution (e.g. via gpu.thread_id) is not present.
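
A rough before/after sketch of what the pass produces, paraphrasing the
static_copy case from the lit test added below (SSA names here are
illustrative and the subview result layouts are reconstructed rather than
copied from the test). A copy left at function scope:

func.func @static_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
  memref.copy %src, %target : memref<56x32xf32> to memref<56x32xf32>
  return
}

is tiled to [1, 4] (128 bits of f32 per forall iteration) and wrapped in a
thread-mapped scf.forall:

func.func @static_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
  scf.forall (%iv0, %iv1) = (0, 0) to (56, 32) step (1, 4) {
    %src_tile = memref.subview %src[%iv0, %iv1] [1, 4] [1, 1]
      : memref<56x32xf32> to memref<1x4xf32, strided<[32, 1], offset: ?>>
    %dst_tile = memref.subview %target[%iv0, %iv1] [1, 4] [1, 1]
      : memref<56x32xf32> to memref<1x4xf32, strided<[32, 1], offset: ?>>
    memref.copy %src_tile, %dst_tile
      : memref<1x4xf32, strided<[32, 1], offset: ?>>
        to memref<1x4xf32, strided<[32, 1], offset: ?>>
  } {mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]}
  return
}

The thread mapping is then resolved by the forall distribution that runs later
in the same pipeline (see the LLVMGPU/Passes.cpp change below).
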
qedawkins authored Jan 17, 2025
1 parent dde5992 commit c1cc4cc
Showing 8 changed files with 283 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -56,6 +56,7 @@ iree_compiler_cc_library(
"GPUCombineValueBarriers.cpp",
"GPUCreateFastSlowPath.cpp",
"GPUDistribute.cpp",
"GPUDistributeCopyUsingForall.cpp",
"GPUDistributeForall.cpp",
"GPUDistributeScfFor.cpp",
"GPUDistributeSharedMemoryCopy.cpp",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/CMakeLists.txt
@@ -54,6 +54,7 @@ iree_cc_library(
"GPUCombineValueBarriers.cpp"
"GPUCreateFastSlowPath.cpp"
"GPUDistribute.cpp"
"GPUDistributeCopyUsingForall.cpp"
"GPUDistributeForall.cpp"
"GPUDistributeScfFor.cpp"
"GPUDistributeSharedMemoryCopy.cpp"
171 changes: 171 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeCopyUsingForall.cpp
@@ -0,0 +1,171 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"

#define DEBUG_TYPE "iree-codegen-gpu-distribute-copy-using-forall"

namespace mlir::iree_compiler {

#define GEN_PASS_DEF_GPUDISTRIBUTECOPYUSINGFORALLPASS
#include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"

namespace {
//===----------------------------------------------------------------------===//
// Pass to distribute memref copies that are not already nested in a
// thread-mapped scf.forall by tiling them into thread-distributed scf.forall
// ops containing copies of per-thread subviews.
//===----------------------------------------------------------------------===//

// For optimal performance we always want each thread to copy 128 bits.
static constexpr int kPreferredCopyNumBits = 128;

// Moves a rank-0 (untileable) copy into a trivial single-threaded scf.forall
// so that it still ends up nested under a thread mapping.
static void distributeCopyToSingleThread(RewriterBase &rewriter,
memref::CopyOp copy) {
SmallVector<Attribute> mapping = {gpu::GPUThreadMappingAttr::get(
rewriter.getContext(), gpu::MappingId::LinearDim0)};
scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
copy.getLoc(), ArrayRef<OpFoldResult>{rewriter.getIndexAttr(0)},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
ArrayRef<OpFoldResult>{rewriter.getIndexAttr(1)},
/*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));
rewriter.moveOpBefore(copy, newForallOp.getBody(),
newForallOp.getBody()->begin());
}

/// Distributes a copy with a thread mapping.
static void distributeCopyToThreads(RewriterBase &rewriter, memref::CopyOp copy,
ArrayRef<OpFoldResult> tileSizes) {
int64_t rank = tileSizes.size();
assert(rank == copy.getTarget().getType().getRank() &&
"tile size and copy rank mismatch");
if (rank == 0) {
distributeCopyToSingleThread(rewriter, copy);
return;
}

Location loc = copy.getLoc();
MLIRContext *context = rewriter.getContext();

SmallVector<OpFoldResult> lowerBounds(rank, rewriter.getIndexAttr(0));
SmallVector<OpFoldResult> upperBounds =
memref::getMixedSizes(rewriter, loc, copy.getSource());

SmallVector<Attribute> mapping;
int idx = 0;
for (int64_t i = 0, e = rank; i < e; ++i) {
unsigned mappingId =
static_cast<unsigned>(gpu::MappingId::LinearDim0) + idx++;
mapping.push_back(gpu::GPUThreadMappingAttr::get(
context, static_cast<gpu::MappingId>(mappingId)));
}
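// Reverse the mapping so that the innermost copy dimension is mapped to
// linear_dim_0, the fastest-varying thread id; adjacent threads then copy
// adjacent innermost tiles.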
mapping = llvm::to_vector(llvm::reverse(mapping));

scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
copy.getLoc(), lowerBounds, upperBounds, tileSizes,
/*outputs=*/ValueRange(), /*mapping=*/rewriter.getArrayAttr(mapping));

rewriter.setInsertionPointToStart(newForallOp.getBody());

AffineExpr d0, d1, d2;
bindDims(context, d0, d1, d2);
SmallVector<OpFoldResult> sizes;
AffineMap minMap =
AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {d0, d1 - d2}, context);
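// Clamp the tile extent at the boundary with min(tileSize, ub - iv), unless
// the upper bound is statically known to divide by the tile size or the tile
// size is 1.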
for (auto [ub, tileSize, iterator] : llvm::zip_equal(
upperBounds, tileSizes, newForallOp.getInductionVars())) {
std::optional<int64_t> staticUb = getConstantIntValue(ub);
std::optional<int64_t> staticTileSize = getConstantIntValue(tileSize);
if ((staticUb && staticTileSize &&
staticUb.value() % staticTileSize.value() == 0) ||
(staticTileSize.value_or(0) == 1)) {
sizes.push_back(tileSize);
} else {
sizes.push_back(
rewriter
.create<affine::AffineMinOp>(
loc, rewriter.getIndexType(), minMap,
ValueRange{
getValueOrCreateConstantIndexOp(rewriter, loc, tileSize),
getValueOrCreateConstantIndexOp(rewriter, loc, ub),
iterator})
.getResult());
}
}

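// Take per-thread subviews of the source and target at the forall induction
// variables and rewrite the original copy as a copy between the subviews.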
SmallVector<OpFoldResult> offsets =
getAsOpFoldResult(newForallOp.getInductionVars());
SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
Value sourceTile = rewriter.create<memref::SubViewOp>(
loc, copy.getSource(), offsets, sizes, strides);
Value targetTile = rewriter.create<memref::SubViewOp>(
loc, copy.getTarget(), offsets, sizes, strides);
rewriter.replaceOpWithNewOp<memref::CopyOp>(copy, sourceTile, targetTile);
}

static SmallVector<OpFoldResult> getCopyTileSizes(Builder &b,
memref::CopyOp copy) {
int64_t rank = copy.getTarget().getType().getRank();
if (rank == 0) {
return {};
}

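// Tile every outer dimension to 1 and the innermost dimension to
// kPreferredCopyNumBits worth of elements (e.g. 4 for f32, 8 for f16).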
SmallVector<OpFoldResult> tileSizes(rank - 1, b.getIndexAttr(1));
int64_t elementBitWidth = llvm::cast<MemRefType>(copy.getTarget().getType())
.getElementTypeBitWidth();
tileSizes.push_back(b.getIndexAttr(kPreferredCopyNumBits / elementBitWidth));
return tileSizes;
}

} // namespace

namespace {
struct GPUDistributeCopyUsingForallPass final
: impl::GPUDistributeCopyUsingForallPassBase<
GPUDistributeCopyUsingForallPass> {
void runOnOperation() override {
MLIRContext *context = &getContext();
auto funcOp = getOperation();

SmallVector<memref::CopyOp> copies;

// Walk in PreOrder so that parent operations are visited before children,
// thus allowing all operations contained within thread/warp/lane foralls
// to be skipped.
funcOp.walk<WalkOrder::PreOrder>([&](Operation *op) {
if (auto forallOp = dyn_cast<scf::ForallOp>(op)) {
// Skip ops contained within forall ops with thread/warp/lane mappings.
if (forallOpHasMappingType<IREE::GPU::LaneIdAttr,
gpu::GPUWarpMappingAttr,
gpu::GPUThreadMappingAttr>(forallOp)) {
return WalkResult::skip();
}
}
if (auto copy = dyn_cast<memref::CopyOp>(op)) {
copies.push_back(copy);
}
return WalkResult::advance();
});

IRRewriter rewriter(context);
for (auto copy : copies) {
rewriter.setInsertionPoint(copy);
SmallVector<OpFoldResult> tileSizes = getCopyTileSizes(rewriter, copy);
distributeCopyToThreads(rewriter, copy, tileSizes);
}
}
};
} // namespace
} // namespace mlir::iree_compiler
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/Passes.td
@@ -31,6 +31,14 @@ def GPUCreateFastSlowPathPass :
let dependentDialects = ["::mlir::scf::SCFDialect"];
}

def GPUDistributeCopyUsingForallPass :
InterfacePass<"iree-codegen-gpu-distribute-copy-using-forall", "mlir::FunctionOpInterface"> {
let summary = "Pass to distribute memref copies to threads using scf.forall.";
let dependentDialects = [
"::mlir::affine::AffineDialect", "::mlir::gpu::GPUDialect", "::mlir::scf::SCFDialect"
];
}

def GPUDistributeForallPass :
InterfacePass<"iree-codegen-gpu-distribute-forall", "mlir::FunctionOpInterface"> {
let summary = "Pass to distribute scf.forall ops.";
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel
@@ -23,6 +23,7 @@ iree_lit_test_suite(
"gpu_check_resource_usage.mlir",
"gpu_create_fast_slow_path.mlir",
"gpu_distribute.mlir",
"gpu_distribute_copy_using_forall.mlir",
"gpu_distribute_forall.mlir",
"gpu_distribute_scf_for.mlir",
"gpu_distribute_shared_memory.mlir",
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt
@@ -20,6 +20,7 @@ iree_lit_test_suite(
"gpu_combine_value_barriers.mlir"
"gpu_create_fast_slow_path.mlir"
"gpu_distribute.mlir"
"gpu_distribute_copy_using_forall.mlir"
"gpu_distribute_forall.mlir"
"gpu_distribute_scf_for.mlir"
"gpu_distribute_shared_memory.mlir"
99 changes: 99 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_copy_using_forall.mlir
@@ -0,0 +1,99 @@
// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(func.func(iree-codegen-gpu-distribute-copy-using-forall))' %s | FileCheck %s

func.func @static_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
memref.copy %src, %target : memref<56x32xf32> to memref<56x32xf32>
return
}

// CHECK-LABEL: func.func @static_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf32>, %[[TARGET:.+]]: memref<56x32xf32>)

// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 32) step (1, 4) {
// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, 4] [1, 1]
// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, 4] [1, 1]
// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

func.func @unaligned_copy(%src : memref<56x31xf32>, %target : memref<56x31xf32>) {
memref.copy %src, %target : memref<56x31xf32> to memref<56x31xf32>
return
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// CHECK-LABEL: func.func @unaligned_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<56x31xf32>, %[[TARGET:.+]]: memref<56x31xf32>)

// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 31) step (1, 4) {
// CHECK: %[[MIN:.+]] = affine.min #[[$MAP]](%c4, %c31, %[[IV1]])
// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

func.func @dynamic_copy(%src : memref<?x?xf32>, %target : memref<?x?xf32>) {
memref.copy %src, %target : memref<?x?xf32> to memref<?x?xf32>
return
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
// CHECK-LABEL: func.func @dynamic_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<?x?xf32>, %[[TARGET:.+]]: memref<?x?xf32>)

// CHECK-DAG: %[[D0:.+]] = memref.dim %[[SRC]], %c0 : memref<?x?xf32>
// CHECK-DAG: %[[D1:.+]] = memref.dim %[[SRC]], %c1 : memref<?x?xf32>
// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (%[[D0]], %[[D1]]) step (1, 4) {
// CHECK: %[[MIN:.+]] = affine.min #[[$MAP]](%c4, %[[D1]], %[[IV1]])
// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, %[[MIN]]]
// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

func.func @f16_copy(%src : memref<56x32xf16>, %target : memref<56x32xf16>) {
memref.copy %src, %target : memref<56x32xf16> to memref<56x32xf16>
return
}

// CHECK-LABEL: func.func @f16_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf16>, %[[TARGET:.+]]: memref<56x32xf16>)

// CHECK: scf.forall (%[[IV0:[A-Za-z0-9]+]], %[[IV1:[A-Za-z0-9]+]]) = (0, 0) to (56, 32) step (1, 8) {
// CHECK-DAG: %[[SRC_SUBVIEW:.+]] = memref.subview %[[SRC]][%[[IV0]], %[[IV1]]] [1, 8]
// CHECK-DAG: %[[TARGET_SUBVIEW:.+]] = memref.subview %[[TARGET]][%[[IV0]], %[[IV1]]] [1, 8]
// CHECK: memref.copy %[[SRC_SUBVIEW]], %[[TARGET_SUBVIEW]]
// CHECK: mapping = [#gpu.thread<linear_dim_1>, #gpu.thread<linear_dim_0>]

// -----

func.func @rank_0_copy(%src : memref<f32>, %target : memref<f32>) {
memref.copy %src, %target : memref<f32> to memref<f32>
return
}

// CHECK-LABEL: func.func @rank_0_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<f32>, %[[TARGET:.+]]: memref<f32>)

// CHECK: scf.forall (%{{.*}}) in (1) {
// CHECK: memref.copy %[[SRC]], %[[TARGET]]
// CHECK: mapping = [#gpu.thread<linear_dim_0>]

// -----

func.func @already_distributed_copy(%src : memref<56x32xf32>, %target : memref<56x32xf32>) {
scf.forall (%arg2) in (1) {
memref.copy %src, %target : memref<56x32xf32> to memref<56x32xf32>
} {mapping = [#gpu.thread<linear_dim_0>]}
return
}

// CHECK-LABEL: func.func @already_distributed_copy
// CHECK-SAME: (%[[SRC:.+]]: memref<56x32xf32>, %[[TARGET:.+]]: memref<56x32xf32>)

// CHECK: scf.forall (%{{.*}}) in (1) {
// CHECK: memref.copy %[[SRC]], %[[TARGET]]
// CHECK: mapping = [#gpu.thread<linear_dim_0>]
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -452,6 +452,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
addGPUBufferizePasses(funcPassManager);

// Step 8. Resolve remaining parallel loops.
funcPassManager.addPass(createGPUDistributeCopyUsingForallPass());
funcPassManager.addPass(iree_compiler::createNormalizeLoopBoundsPass(
NormalizeLoopBoundsPassOptions{/*normalizeFor=*/false,
/*normalizeForall=*/true}));
