[GPU] Add pattern to fuse tensor.collapse_shape into forall producer
Signed-off-by: Max Dawkins <[email protected]>
Max191 committed Dec 4, 2024
1 parent 9f8aad8 commit 102bc5a
Showing 10 changed files with 569 additions and 4 deletions.
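At a high level, the new pattern hoists a tensor.collapse_shape op whose source is the result of an scf.forall op into the loop body, so the forall produces the collapsed tensor directly. A minimal before/after sketch (illustrative only; value names are invented and the actual output IR may differ in detail). Before:

    %1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
      %in = tensor.extract_slice %arg2[%arg1, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<1x8xf32>
      // ... compute %out : tensor<1x8xf32> ...
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %out into %arg2[%arg1, 0] [1, 8] [1, 1] : tensor<1x8xf32> into tensor<8x8xf32>
      }
    } {mapping = [#gpu.thread<x>]}
    %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>

After fusion, each slice is taken from the collapsed tensor, with tensor.expand_shape/tensor.collapse_shape ops bridging the types inside the loop:

    %empty = tensor.empty() : tensor<64xf32>
    %1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %empty) -> (tensor<64xf32>) {
      %off = affine.apply affine_map<(d0) -> (d0 * 8)>(%arg1)
      %flat_in = tensor.extract_slice %arg2[%off] [8] [1] : tensor<64xf32> to tensor<8xf32>
      %in = tensor.expand_shape %flat_in [[0, 1]] output_shape [1, 8] : tensor<8xf32> into tensor<1x8xf32>
      // ... compute %out : tensor<1x8xf32> ...
      %flat_out = tensor.collapse_shape %out [[0, 1]] : tensor<1x8xf32> into tensor<8xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %flat_out into %arg2[%off] [8] [1] : tensor<8xf32> into tensor<64xf32>
      }
    } {mapping = [#gpu.thread<x>]}

All uses of %collapsed are then replaced by the new forall result.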
@@ -11,8 +11,6 @@
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -22,7 +20,6 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
@@ -325,6 +322,24 @@ struct FuseTilableForallConsumers final
  }
};

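// Fuses a tensor.collapse_shape op into the scf.forall op producing its
// source tensor. The rewrite itself is implemented by
// fuseCollapseShapeIntoProducerForall; this pattern only matches a
// collapse_shape whose source is defined by an scf.forall.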
struct FuseCollapseShapeConsumers final
    : OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
                                PatternRewriter &rewriter) const override {
    auto forallOp = collapseOp.getSrc().getDefiningOp<scf::ForallOp>();
    if (!forallOp) {
      return rewriter.notifyMatchFailure(collapseOp, "No forall op producer");
    }

    if (failed(fuseCollapseShapeIntoProducerForall(rewriter, forallOp,
                                                   collapseOp))) {
      return failure();
    }
    return success();
  }
};

void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  MLIRContext *context = &getContext();

@@ -369,6 +384,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  patterns.add<FuseTilableDestinationProducers>(context);
  patterns.add<FuseUnitLoopDestination>(context);
  patterns.add<FuseTilableForallConsumers>(context);
  patterns.add<FuseCollapseShapeConsumers>(context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  scf::ForallOp::getCanonicalizationPatterns(patterns, context);
  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
@@ -518,3 +518,57 @@ func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -
// CHECK: scf.forall.in_parallel
// CHECK: linalg.add
// CHECK: return

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_non_contiguous_collapse_shape(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
    %2 = affine.apply #map(%arg1)
    %extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
    %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
    %3 = linalg.copy ins(%extracted_slice : tensor<2x7xf32>) outs(%extracted_slice_0 : tensor<2x7xf32>) -> tensor<2x7xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 7] [1, 1] : tensor<2x7xf32> into tensor<8x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
  return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_non_contiguous_collapse_shape
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2x7xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]

// -----

func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
    %extracted_slice = tensor.extract_slice %arg0[%arg1, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
    %extracted_slice_0 = tensor.extract_slice %arg2[%arg1, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
    %2 = linalg.copy ins(%extracted_slice : tensor<8xf32>) outs(%extracted_slice_0 : tensor<8xf32>) -> tensor<8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %2 into %arg2[%arg1, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<8x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
  return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_collapse_shape_rank_reduced
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<8xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]
@@ -218,6 +218,54 @@ void transform_dialect::FuseForallOp::getEffects(
transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// FuseCollapseShapeWithForallOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::FuseCollapseShapeWithForallOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto producers = state.getPayloadOps(getProducer());
  auto consumers = state.getPayloadOps(getConsumer());

  int64_t numProducers = llvm::range_size(producers);
  int64_t numConsumers = llvm::range_size(consumers);
  if (numProducers != 1 || numConsumers != 1) {
    return mlir::emitDefiniteFailure(
        state.getTopLevel(), "expected exactly one producer and one consumer");
  }

  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
  if (!producer) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "producer is not an scf.forall op");
  }
  auto consumer = dyn_cast<tensor::CollapseShapeOp>(*consumers.begin());
  if (!consumer) {
    return mlir::emitDefiniteFailure(
        state.getTopLevel(), "consumer is not a tensor.collapse_shape op");
  }

  FailureOr<scf::ForallOp> fusedForallOp =
      GPU::fuseCollapseShapeIntoProducerForall(rewriter, producer, consumer);
  if (failed(fusedForallOp)) {
    return mlir::emitSilenceableFailure(state.getTopLevel(),
                                        "failed to fuse collapse_shape op");
  }

  results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
  return DiagnosedSilenceableFailure::success();
}

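// Both operand handles are consumed: the producer forall is rewritten and the
// collapse_shape consumer is folded away, so stale handles to either must not
// be reused. A fresh handle to the fused forall is produced instead.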
void transform_dialect::FuseCollapseShapeWithForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getProducerMutable(), effects);
  transform::consumesHandle(getConsumerMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

} // namespace mlir::iree_compiler::IREE

void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
@@ -228,4 +228,38 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def FuseCollapseShapeWithForallOp : Op<Transform_Dialect, "iree.fuse_collapse_shape_with_forall",
    [FunctionalStyleTransformOpTrait,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<TransformOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Fuses a consumer tensor.collapse_shape op into a producer scf.forall op.
    The only users of the block argument for the corresponding forall output
    operand must be a single tensor.parallel_insert_slice op and
    tensor.extract_slice ops that extract an equivalent subset. After the
    fusion, the forall output is collapsed, and all users of this block
    argument are collapsed as well. A tensor.expand_shape op is inserted after
    each tensor.extract_slice user inside the forall so that types match, and,
    symmetrically, a tensor.collapse_shape op is inserted before the
    tensor.parallel_insert_slice.

    #### Return modes
    Emits a definite failure if the producer is not an scf.forall op, if the
    consumer is not a tensor.collapse_shape op, or if either handle does not
    map to exactly one payload op.
  }];

  let arguments = (
    ins TransformHandleTypeInterface:$producer,
        TransformHandleTypeInterface:$consumer
  );
  let results = (outs TransformHandleTypeInterface:$result);

  let assemblyFormat = [{
    $consumer `into` $producer attr-dict
    `:` functional-type(operands, results)
  }];
  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
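For reference, a usage sketch of the new transform op (the named sequence and matcher plumbing here are illustrative; the new transform_fuse_collapse_shape_with_forall.mlir lit test registered below is the authoritative example):

    module attributes { transform.with_named_sequence } {
      transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
        %forall = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
        %collapse = transform.structured.match ops{["tensor.collapse_shape"]} in %root : (!transform.any_op) -> !transform.any_op
        %fused = transform.iree.fuse_collapse_shape_with_forall %collapse into %forall
          : (!transform.any_op, !transform.any_op) -> !transform.any_op
        transform.yield
      }
    }

Both input handles are consumed; %fused is a fresh handle to the rewritten scf.forall.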
@@ -24,6 +24,7 @@ iree_lit_test_suite(
        "drop_multi_mma_unit_dims.mlir",
        "lower_multi_mma.mlir",
        "lower_vector_barrier.mlir",
        "transform_fuse_collapse_shape_with_forall.mlir",
        "transform_fuse_forall.mlir",
        "transform_lower_barrier_region.mlir",
        "vectorize_iree_gpu_ops.mlir",
@@ -20,6 +20,7 @@ iree_lit_test_suite(
    "drop_multi_mma_unit_dims.mlir"
    "lower_multi_mma.mlir"
    "lower_vector_barrier.mlir"
    "transform_fuse_collapse_shape_with_forall.mlir"
    "transform_fuse_forall.mlir"
    "transform_lower_barrier_region.mlir"
    "unroll_multi_mma.mlir"