[GPU] Add pattern to fuse tensor.collapse_shape into forall producer
Signed-off-by: Max Dawkins <max.dawkins@gmail.com>
Max191 committed Jan 21, 2025
1 parent 4c0ba9c commit da68bc2
Showing 11 changed files with 593 additions and 8 deletions.
@@ -12,8 +12,6 @@
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -23,7 +21,6 @@
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
@@ -331,6 +328,24 @@ struct FuseTilableForallConsumers final
}
};

struct FuseCollapseShapeConsumers final
    : OpRewritePattern<tensor::CollapseShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
                                PatternRewriter &rewriter) const override {
    auto forallOp = collapseOp.getSrc().getDefiningOp<scf::ForallOp>();
    if (!forallOp) {
      return rewriter.notifyMatchFailure(collapseOp, "No forall op producer");
    }

    if (failed(fuseCollapseShapeIntoProducerForall(rewriter, forallOp,
                                                   collapseOp))) {
      return failure();
    }
    return success();
  }
};

void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  MLIRContext *context = &getContext();

@@ -375,6 +390,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
  patterns.add<FuseTilableDestinationProducers>(context);
  patterns.add<FuseUnitLoopDestination>(context);
  patterns.add<FuseTilableForallConsumers>(context);
  patterns.add<FuseCollapseShapeConsumers>(context);
  populateSwapExtractWithExpandPattern(patterns);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  scf::ForallOp::getCanonicalizationPatterns(patterns, context);
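
To exercise the new pattern in isolation, the pass can be run directly with iree-opt. A hedged sketch of a lit RUN line follows; the exact pass flag is assumed to match the DEBUG_TYPE above and is not part of this diff:

// RUN: iree-opt %s --split-input-file \
// RUN:   --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-fuse-and-hoist-parallel-loops))" \
// RUN:   | FileCheck %s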
@@ -546,3 +546,57 @@ func.func @fuse_imperfectly_aligned_unpack(%arg0: tensor<5x31xf16>, %arg1: index
// CHECK: linalg.copy
// CHECK: scf.forall.in_parallel
// CHECK: return

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_non_contiguous_collapse_shape(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
    %2 = affine.apply #map(%arg1)
    %extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
    %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
    %3 = linalg.copy ins(%extracted_slice : tensor<2x7xf32>) outs(%extracted_slice_0 : tensor<2x7xf32>) -> tensor<2x7xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 7] [1, 1] : tensor<2x7xf32> into tensor<8x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
  return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_non_contiguous_collapse_shape
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2x7xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]

// -----

#map = affine_map<(d0) -> (d0 * 2)>
func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
    %2 = affine.apply #map(%arg1)
    %extracted_slice = tensor.extract_slice %arg0[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
    %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
    %3 = linalg.copy ins(%extracted_slice : tensor<8xf32>) outs(%extracted_slice_0 : tensor<8xf32>) -> tensor<8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<8x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
  return %collapsed : tensor<64xf32>
}

// CHECK-LABEL: func @no_fuse_collapse_shape_rank_reduced
// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
// CHECK: scf.forall.in_parallel {
// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<8xf32> into tensor<8x8xf32>
// CHECK: }
// CHECK: } {mapping = [#gpu.thread<x>]}
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
// CHECK: return %[[COLLAPSE]]
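
Both new tests above are negative cases. For contrast, a minimal sketch of IR the pattern is expected to fuse, using contiguous, non-rank-reducing slices along the collapsed dimensions, is shown below; shapes and names are illustrative and not taken from this commit:

#map = affine_map<(d0) -> (d0 * 2)>
func.func @fusible_collapse_shape(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
  %0 = tensor.empty() : tensor<8x8xf32>
  %1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
    %2 = affine.apply #map(%arg1)
    %in = tensor.extract_slice %arg0[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
    %out = tensor.extract_slice %arg2[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
    %3 = linalg.copy ins(%in : tensor<2x8xf32>) outs(%out : tensor<2x8xf32>) -> tensor<2x8xf32>
    scf.forall.in_parallel {
      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<8x8xf32>
    }
  } {mapping = [#gpu.thread<x>]}
  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
  return %collapsed : tensor<64xf32>
}

Here each 2x8 slice of the 8x8 shared output collapses to a contiguous length-16 slice of the 64-element result, so the collapse_shape can be folded into the forall body.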
@@ -218,6 +218,54 @@ void transform_dialect::FuseForallOp::getEffects(
transform::modifiesPayload(effects);
}

//===---------------------------------------------------------------------===//
// FuseCollapseShapeIntoForallOp
//===---------------------------------------------------------------------===//

DiagnosedSilenceableFailure
transform_dialect::FuseCollapseShapeIntoForallOp::apply(
    transform::TransformRewriter &rewriter,
    transform::TransformResults &results, transform::TransformState &state) {
  auto producers = state.getPayloadOps(getProducer());
  auto consumers = state.getPayloadOps(getConsumer());

  int64_t numProducers = llvm::range_size(producers);
  int64_t numConsumers = llvm::range_size(consumers);
  if (numProducers != 1 || numConsumers != 1) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "More than one producer or consumer");
  }

  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
  if (!producer) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "Non-forall producer");
  }
  auto consumer = dyn_cast<tensor::CollapseShapeOp>(*consumers.begin());
  if (!consumer) {
    return mlir::emitDefiniteFailure(state.getTopLevel(),
                                     "Non-collapse_shape consumer");
  }

  FailureOr<scf::ForallOp> fusedForallOp =
      GPU::fuseCollapseShapeIntoProducerForall(rewriter, producer, consumer);
  if (failed(fusedForallOp)) {
    return mlir::emitSilenceableFailure(state.getTopLevel(),
                                        "failed to fuse collapse_shape op");
  }

  results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
  return DiagnosedSilenceableFailure::success();
}

void transform_dialect::FuseCollapseShapeIntoForallOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getProducerMutable(), effects);
  transform::consumesHandle(getConsumerMutable(), effects);
  transform::producesHandle(getOperation()->getOpResults(), effects);
  transform::modifiesPayload(effects);
}

} // namespace mlir::iree_compiler::IREE

void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
@@ -228,4 +228,38 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall",
let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

def FuseCollapseShapeIntoForallOp : Op<Transform_Dialect, "iree.fuse_collapse_shape_into_forall",
    [FunctionalStyleTransformOpTrait,
     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
     DeclareOpInterfaceMethods<TransformOpInterface>,
     ReportTrackingListenerFailuresOpTrait]> {
  let description = [{
    Fuses a consumer tensor.collapse_shape op into a producer scf.forall op.
    The only users of the block argument for the corresponding forall output
    operand should be a single tensor.parallel_insert_slice op and
    tensor.extract_slice ops that extract an equivalent subset. After fusion,
    the output of the forall is collapsed, and all users of this block argument
    are collapsed as well: tensor.expand_shape ops are inserted after any
    tensor.extract_slice users inside the forall so that types match, and a
    tensor.collapse_shape is inserted before the tensor.parallel_insert_slice.

    #### Return modes
    Emits a definite failure if the producer handle does not map to exactly one
    scf.forall op, or if the consumer handle does not map to exactly one
    tensor.collapse_shape op. Emits a silenceable failure if the fusion itself
    fails.
  }];

  let arguments = (
    ins TransformHandleTypeInterface:$producer,
        TransformHandleTypeInterface:$consumer
  );
  let results = (outs TransformHandleTypeInterface:$result);

  let assemblyFormat = [{
    $consumer `into` $producer attr-dict
    `:` functional-type(operands, results)
  }];
  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
}

#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
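
Following the assembly format above, a hypothetical transform script using the new op might look like this; the transform-interpreter harness and the match logic are assumptions for illustration, not part of this commit:

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    %forall = transform.structured.match ops{["scf.forall"]} in %root
      : (!transform.any_op) -> !transform.any_op
    %collapse = transform.structured.match ops{["tensor.collapse_shape"]} in %root
      : (!transform.any_op) -> !transform.any_op
    // Fuse the collapse_shape consumer into its producing forall; both input
    // handles are consumed and a handle to the fused forall is returned.
    %fused = transform.iree.fuse_collapse_shape_into_forall %collapse into %forall
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
    transform.yield
  }
}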
@@ -24,6 +24,7 @@ iree_lit_test_suite(
"drop_multi_mma_unit_dims.mlir",
"lower_multi_mma.mlir",
"lower_vector_barrier.mlir",
"transform_fuse_collapse_shape_with_forall.mlir",
"transform_fuse_forall.mlir",
"transform_lower_barrier_region.mlir",
"vectorize_iree_gpu_ops.mlir",
@@ -20,6 +20,7 @@ iree_lit_test_suite(
"drop_multi_mma_unit_dims.mlir"
"lower_multi_mma.mlir"
"lower_vector_barrier.mlir"
"transform_fuse_collapse_shape_with_forall.mlir"
"transform_fuse_forall.mlir"
"transform_lower_barrier_region.mlir"
"unroll_multi_mma.mlir"