From 102bc5a6793df32aeac2278fe8c7e1e72e3e90b8 Mon Sep 17 00:00:00 2001
From: Max Dawkins
Date: Tue, 3 Dec 2024 09:55:20 -0600
Subject: [PATCH] [GPU] Add pattern to fuse tensor.collapse_shape into forall producer

Signed-off-by: Max Dawkins
---
 .../GPU/GPUFuseAndHoistParallelLoops.cpp      |  22 +-
 .../GPU/test/gpu_fuse_and_hoist_forall.mlir   |  54 +++++
 .../TransformExtensions/IREEGPUExtensions.cpp |  48 ++++
 .../IREEGPUExtensionsOps.td                   |  34 +++
 .../GPU/TransformExtensions/test/BUILD.bazel  |   1 +
 .../TransformExtensions/test/CMakeLists.txt   |   1 +
 ...sform_fuse_collapse_shape_with_forall.mlir | 157 +++++++++++++
 .../Dialect/GPU/Transforms/Transforms.cpp     | 211 ++++++++++++++++++
 .../Dialect/GPU/Transforms/Transforms.h       |  43 ++++
 .../test/ROCDL/pipeline_tile_and_fuse.mlir    |   2 +-
 10 files changed, 569 insertions(+), 4 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_collapse_shape_with_forall.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
index 73b5a861f0c7..f6bb86d8eba7 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
@@ -11,8 +11,6 @@
 #include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Transforms/Transforms.h"
 #include "iree/compiler/Codegen/Utils/GPUUtils.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Casting.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -22,7 +20,6 @@
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
-#include "mlir/Support/LogicalResult.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-fuse-and-hoist-parallel-loops"
@@ -325,6 +322,24 @@ struct FuseTilableForallConsumers final
   }
 };
 
+struct FuseCollapseShapeConsumers final
+    : OpRewritePattern<tensor::CollapseShapeOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::CollapseShapeOp collapseOp,
+                                PatternRewriter &rewriter) const override {
+    auto forallOp = collapseOp.getSrc().getDefiningOp<scf::ForallOp>();
+    if (!forallOp) {
+      return rewriter.notifyMatchFailure(collapseOp, "No forall op producer");
+    }
+
+    if (failed(fuseCollapseShapeIntoProducerForall(rewriter, forallOp,
+                                                   collapseOp))) {
+      return failure();
+    }
+    return success();
+  }
+};
+
 void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
   MLIRContext *context = &getContext();
 
@@ -369,6 +384,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() {
   patterns.add(context);
   patterns.add(context);
   patterns.add(context);
+  patterns.add<FuseCollapseShapeConsumers>(context);
   tensor::populateFoldTensorEmptyPatterns(patterns);
   scf::ForallOp::getCanonicalizationPatterns(patterns, context);
   if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
index fbf53eba63ad..1bc6608e3a3c 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir
@@ -518,3 +518,57 @@ func.func @no_fuse_multi_use(%2: tensor<128x128xf16>, %3: tensor<128x128xf16>) -
 // CHECK:         scf.forall.in_parallel
 // CHECK:       linalg.add
 // CHECK:       return
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+func.func @no_fuse_non_contiguous_collapse_shape(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
+    %2 = affine.apply #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 7] [1, 1] : tensor<8x8xf32> to tensor<2x7xf32>
+    %3 = linalg.copy ins(%extracted_slice : tensor<2x7xf32>) outs(%extracted_slice_0 : tensor<2x7xf32>) -> tensor<2x7xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 7] [1, 1] : tensor<2x7xf32> into tensor<8x8xf32>
+    }
+  } {mapping = [#gpu.thread]}
+  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
+  return %collapsed : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @no_fuse_non_contiguous_collapse_shape
+// CHECK:     %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
+// CHECK:       scf.forall.in_parallel {
+// CHECK-DAG:     tensor.parallel_insert_slice {{.*}} : tensor<2x7xf32> into tensor<8x8xf32>
+// CHECK:       }
+// CHECK:     } {mapping = [#gpu.thread]}
+// CHECK:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
+// CHECK:     return %[[COLLAPSE]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
+  %0 = tensor.empty() : tensor<8x8xf32>
+  %1 = scf.forall (%arg1) in (8) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
+    %2 = affine.apply #map(%arg1)
+    %extracted_slice = tensor.extract_slice %arg0[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
+    %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [1, 8] [1, 1] : tensor<8x8xf32> to tensor<8xf32>
+    %3 = linalg.copy ins(%extracted_slice : tensor<8xf32>) outs(%extracted_slice_0 : tensor<8xf32>) -> tensor<8xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %3 into %arg2[%2, 0] [1, 8] [1, 1] : tensor<8xf32> into tensor<8x8xf32>
+    }
+  } {mapping = [#gpu.thread]}
+  %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
+  return %collapsed : tensor<64xf32>
+}
+
+// CHECK-LABEL: func @no_fuse_collapse_shape_rank_reduced
+// CHECK:     %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<8x8xf32>) {
+// CHECK:       scf.forall.in_parallel {
+// CHECK-DAG:     tensor.parallel_insert_slice {{.*}} : tensor<8xf32> into tensor<8x8xf32>
+// CHECK:       }
+// CHECK:     } {mapping = [#gpu.thread]}
+// CHECK:     %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]]
+// CHECK:     return %[[COLLAPSE]]
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
index 3ee4db9c9e7c..e745ba4b9203 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp
@@ -218,6 +218,54 @@ void transform_dialect::FuseForallOp::getEffects(
   transform::modifiesPayload(effects);
 }
 
+//===---------------------------------------------------------------------===//
+// FuseCollapseShapeWithForallOp
+//===---------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure
+transform_dialect::FuseCollapseShapeWithForallOp::apply(
+    transform::TransformRewriter &rewriter,
+    transform::TransformResults &results, transform::TransformState &state) {
+  auto producers = state.getPayloadOps(getProducer());
+  auto consumers = state.getPayloadOps(getConsumer());
+
+  int64_t numProducers = llvm::range_size(producers);
+  int64_t numConsumers = llvm::range_size(consumers);
+  if (numProducers != 1 || numConsumers != 1) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "More than one producer or consumer");
+  }
+
+  auto producer = dyn_cast<scf::ForallOp>(*producers.begin());
+  if (!producer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non-forall producer");
+  }
+  auto consumer = dyn_cast<tensor::CollapseShapeOp>(*consumers.begin());
+  if (!consumer) {
+    return mlir::emitDefiniteFailure(state.getTopLevel(),
+                                     "Non-collapse_shape consumer");
+  }
+
+  FailureOr<scf::ForallOp> fusedForallOp =
+      GPU::fuseCollapseShapeIntoProducerForall(rewriter, producer, consumer);
+  if (failed(fusedForallOp)) {
+    return mlir::emitSilenceableFailure(state.getTopLevel(),
+                                        "failed to fuse collapse_shape op");
+  }
+
+  results.set(getOperation()->getOpResult(0), {fusedForallOp.value()});
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform_dialect::FuseCollapseShapeWithForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  transform::consumesHandle(getProducerMutable(), effects);
+  transform::consumesHandle(getConsumerMutable(), effects);
+  transform::producesHandle(getOperation()->getOpResults(), effects);
+  transform::modifiesPayload(effects);
+}
+
 } // namespace mlir::iree_compiler::IREE
 
 void mlir::iree_compiler::registerTransformDialectIREEGPUExtension(
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
index ceba3f086b99..e570ce62889e 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td
@@ -228,4 +228,38 @@ def FuseForallOp : Op<Transform_Dialect, "iree.fuse_forall", [
   let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
 }
 
+def FuseCollapseShapeWithForallOp : Op<Transform_Dialect, "iree.fuse_collapse_shape_with_forall", [
+    FunctionalStyleTransformOpTrait,
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<TransformOpInterface>,
+    ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Fuses a consumer tensor.collapse_shape op into a producer scf.forall op.
+    The only users of the block argument for the corresponding forall output
+    operand should be a single tensor.parallel_insert_slice op, plus
+    tensor.extract_slice ops that extract an equivalent subset. After the
+    fusion, the output of the forall is collapsed, and all users of this block
+    argument are collapsed as well. Additional tensor.expand_shape ops are
+    inserted after any tensor.extract_slice users inside the forall so that
+    types match. Similarly, a tensor.collapse_shape is inserted before the
+    tensor.parallel_insert_slice.
+
+    #### Return modes
+    Emits a definite failure if the producer is not an scf.forall op or the
+    consumer is not a tensor.collapse_shape op.
+  }];
+
+  let arguments = (
+    ins TransformHandleTypeInterface:$producer,
+        TransformHandleTypeInterface:$consumer
+  );
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = [{
+    $consumer `into` $producer attr-dict
+    `:` functional-type(operands, results)
+  }];
+  let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect";
+}
+
 #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
index 1b99c9f19422..428211b3ea01 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel
@@ -24,6 +24,7 @@ iree_lit_test_suite(
             "drop_multi_mma_unit_dims.mlir",
             "lower_multi_mma.mlir",
             "lower_vector_barrier.mlir",
+            "transform_fuse_collapse_shape_with_forall.mlir",
             "transform_fuse_forall.mlir",
             "transform_lower_barrier_region.mlir",
             "vectorize_iree_gpu_ops.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
index 04481680b3ac..344da8cf34d9 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt
@@ -20,6 +20,7 @@ iree_lit_test_suite(
     "drop_multi_mma_unit_dims.mlir"
     "lower_multi_mma.mlir"
    "lower_vector_barrier.mlir"
+    "transform_fuse_collapse_shape_with_forall.mlir"
     "transform_fuse_forall.mlir"
     "transform_lower_barrier_region.mlir"
     "unroll_multi_mma.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_collapse_shape_with_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_collapse_shape_with_forall.mlir
new file mode 100644
index 000000000000..aa562ee7580e
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_collapse_shape_with_forall.mlir
@@ -0,0 +1,157 @@
+// RUN: iree-opt %s -iree-transform-dialect-interpreter -transform-dialect-drop-schedule -canonicalize -cse --split-input-file | FileCheck %s
+
+#map = affine_map<(d0) -> (d0 * 2)>
+module {
+  func.func @fuse_collapse_shape_with_forall(%arg0: tensor<8x8xf32>) -> tensor<64xf32> {
+    %0 = tensor.empty() : tensor<8x8xf32>
+    %1 = scf.forall (%arg1) in (4) shared_outs(%arg2 = %0) -> (tensor<8x8xf32>) {
+      %2 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg0[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg2[%2, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+      %3 = linalg.copy ins(%extracted_slice : tensor<2x8xf32>) outs(%extracted_slice_0 : tensor<2x8xf32>) -> tensor<2x8xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %3 into %arg2[%2, 0] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<8x8xf32>
+      }
+    } {mapping = [#gpu.thread]}
+    %collapsed = tensor.collapse_shape %1 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
+    return %collapsed : tensor<64xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %producer = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %consumer = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.iree.fuse_collapse_shape_with_forall %consumer into %producer
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
+
+// CHECK-LABEL: func @fuse_collapse_shape_with_forall
+// CHECK-SAME:    %[[ARG0:[A-Za-z0-9]+]]: tensor<8x8xf32>
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32>
+// CHECK:     %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[EMPTY]] {{\[}}[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
+// CHECK:     %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX:.+]]) in (4) shared_outs(%[[COLLAPSED_BBARG:.+]] = %[[COLLAPSED_OUT]]) -> (tensor<64xf32>) {
+// CHECK-DAG:   %[[EXPANDED_BBARG:.+]] = tensor.expand_shape %[[COLLAPSED_BBARG]]
+// CHECK-SAME:    {{\[}}[0, 1]] output_shape [8, 8] : tensor<64xf32> into tensor<8x8xf32>
+// CHECK-DAG:   %[[SLICE_IDX_0:.+]] = affine.apply #[[$MAP]](%[[IDX]])
+// CHECK-DAG:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[SLICE_IDX_0]], 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+// CHECK-DAG:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EXPANDED_BBARG]][%[[SLICE_IDX_0]], 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+// CHECK-DAG:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2x8xf32>) outs(%[[OUT_SLICE]] : tensor<2x8xf32>) -> tensor<2x8xf32>
+// CHECK-DAG:   %[[LINEAR_SLICE_IDX:.+]] = affine.linearize_index [%[[SLICE_IDX_0]], %[[C0]]] by (8, 8) : index
+// CHECK-DAG:   %[[COLLAPSED_COPY:.+]] = tensor.collapse_shape %[[COPY]] {{\[}}[0, 1]] : tensor<2x8xf32> into tensor<16xf32>
+// CHECK:       scf.forall.in_parallel {
+// CHECK:         tensor.parallel_insert_slice %[[COLLAPSED_COPY]] into %[[COLLAPSED_BBARG]][%[[LINEAR_SLICE_IDX]]] [16] [1] : tensor<16xf32> into tensor<64xf32>
+// CHECK:       }
+// CHECK:     } {mapping = [#gpu.thread]}
+// CHECK:     return %[[FORALL_RESULT]]
+
+// -----
+
+#map = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+module {
+  func.func @fuse_dynamic_collapse_shape_with_forall(%arg0: tensor<?x?x8xf32>, %arg1: index, %arg2: index) -> tensor<?x?xf32> {
+    %0 = tensor.empty(%arg1, %arg2) : tensor<?x?x8xf32>
+    %1 = scf.forall (%arg3, %arg4) = (0, 0) to (%arg1, %arg2) step (4, 4) shared_outs(%arg5 = %0) -> (tensor<?x?x8xf32>) {
+      %2 = affine.min #map(%arg3)[%arg1]
+      %3 = affine.min #map(%arg4)[%arg2]
+      %extracted_slice = tensor.extract_slice %arg0[%arg3, %arg4, 0] [%2, %3, 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg5[%arg3, %arg4, 0] [%2, %3, 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+      %4 = linalg.copy ins(%extracted_slice : tensor<?x?x8xf32>) outs(%extracted_slice_0 : tensor<?x?x8xf32>) -> tensor<?x?x8xf32>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %4 into %arg5[%arg3, %arg4, 0] [%2, %3, 8] [1, 1, 1] : tensor<?x?x8xf32> into tensor<?x?x8xf32>
+      }
+    } {mapping = [#gpu.thread, #gpu.thread]}
+    %collapsed = tensor.collapse_shape %1 [[0], [1, 2]] : tensor<?x?x8xf32> into tensor<?x?xf32>
+    return %collapsed : tensor<?x?xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.iree.fuse_collapse_shape_with_forall %1 into %0 : (!transform.any_op,
+      !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 * 8)>
+
+// CHECK-LABEL: func @fuse_dynamic_collapse_shape_with_forall
+// CHECK-SAME:    %[[ARG0:[A-Za-z0-9]+]]: tensor<?x?x8xf32>
+// CHECK-SAME:    %[[SIZE0:[A-Za-z0-9]+]]: index
+// CHECK-SAME:    %[[SIZE1:[A-Za-z0-9]+]]: index
+
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SIZE0]], %[[SIZE1]]) : tensor<?x?x8xf32>
+// CHECK:     %[[COLLAPSED_OUT:.+]] = tensor.collapse_shape %[[EMPTY]] {{\[}}[0], [1, 2]] : tensor<?x?x8xf32> into tensor<?x?xf32>
+// CHECK:     %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX0:.+]], %[[IDX1:.+]]) = (0, 0) to (%[[SIZE0]], %[[SIZE1]]) step (4, 4)
+// CHECK-SAME:    shared_outs(%[[COLLAPSED_BBARG:.+]] = %[[COLLAPSED_OUT]]) -> (tensor<?x?xf32>) {
+// CHECK-DAG:   %[[EXPANDED_BBARG:.+]] = tensor.expand_shape %[[COLLAPSED_BBARG]]
+// CHECK-SAME:    {{\[}}[0], [1, 2]] output_shape [%[[SIZE0]], %[[SIZE1]], 8] : tensor<?x?xf32> into tensor<?x?x8xf32>
+// CHECK-DAG:   %[[SLICE_SIZE_0:.+]] = affine.min #map(%[[IDX0]])[%[[SIZE0]]]
+// CHECK-DAG:   %[[SLICE_SIZE_1:.+]] = affine.min #map(%[[IDX1]])[%[[SIZE1]]]
+// CHECK-DAG:   %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]]
+// CHECK-SAME:    [%[[IDX0]], %[[IDX1]], 0]{{.*}}[%[[SLICE_SIZE_0]], %[[SLICE_SIZE_1]], 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+// CHECK-DAG:   %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EXPANDED_BBARG]]
+// CHECK-SAME:    [%[[IDX0]], %[[IDX1]], 0] [%[[SLICE_SIZE_0]], %[[SLICE_SIZE_1]], 8] [1, 1, 1] : tensor<?x?x8xf32> to tensor<?x?x8xf32>
+// CHECK-DAG:   %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<?x?x8xf32>) outs(%[[OUT_SLICE]] : tensor<?x?x8xf32>) -> tensor<?x?x8xf32>
+// CHECK-DAG:   %[[LINEAR_SLICE_IDX:.+]] = affine.linearize_index [%[[IDX1]], %[[C0]]] by (%[[SIZE1]], 8) : index
+// CHECK-DAG:   %[[COLLAPSED_SLICE_SIZE:.+]] = affine.apply #[[$MAP1]](%[[SLICE_SIZE_1]])
+// CHECK-DAG:   %[[COLLAPSED_COPY:.+]] = tensor.collapse_shape %[[COPY]] {{\[}}[0], [1, 2]] : tensor<?x?x8xf32> into tensor<?x?xf32>
+// CHECK:       scf.forall.in_parallel {
+// CHECK:         tensor.parallel_insert_slice %[[COLLAPSED_COPY]] into %[[COLLAPSED_BBARG]]
+// CHECK-SAME:      [%[[IDX0]], %[[LINEAR_SLICE_IDX]]] [%[[SLICE_SIZE_0]], %[[COLLAPSED_SLICE_SIZE]]] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK:       }
+// CHECK:     } {mapping = [#gpu.thread, #gpu.thread]}
+// CHECK:     return %[[FORALL_RESULT]]
+
+// -----
+
+#map = affine_map<(d0) -> (d0 * 2)>
+module {
+  func.func @fuse_collapse_shape_with_multi_result_forall(%arg0: tensor<8x8xf32>, %arg1: tensor<8x8xf16>) -> (tensor<64xf32>, tensor<8x8xf16>) {
+    %0 = tensor.empty() : tensor<8x8xf32>
+    %1 = tensor.empty() : tensor<8x8xf16>
+    %2:2 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0, %arg4 = %1) -> (tensor<8x8xf32>, tensor<8x8xf16>) {
+      %3 = affine.apply #map(%arg2)
+      %extracted_slice = tensor.extract_slice %arg0[%3, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+      %extracted_slice_0 = tensor.extract_slice %arg3[%3, 0] [2, 8] [1, 1] : tensor<8x8xf32> to tensor<2x8xf32>
+      %extracted_slice_1 = tensor.extract_slice %arg1[%3, 0] [2, 8] [1, 1] : tensor<8x8xf16> to tensor<2x8xf16>
+      %extracted_slice_2 = tensor.extract_slice %arg4[%3, 0] [2, 8] [1, 1] : tensor<8x8xf16> to tensor<2x8xf16>
+      %4 = linalg.copy ins(%extracted_slice : tensor<2x8xf32>) outs(%extracted_slice_0 : tensor<2x8xf32>) -> tensor<2x8xf32>
+      %5 = linalg.copy ins(%extracted_slice_1 : tensor<2x8xf16>) outs(%extracted_slice_2 : tensor<2x8xf16>) -> tensor<2x8xf16>
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %4 into %arg3[%3, 0] [2, 8] [1, 1] : tensor<2x8xf32> into tensor<8x8xf32>
+        tensor.parallel_insert_slice %5 into %arg4[%3, 0] [2, 8] [1, 1] : tensor<2x8xf16> into tensor<8x8xf16>
+      }
+    } {mapping = [#gpu.thread]}
+    %collapsed = tensor.collapse_shape %2#0 [[0, 1]] : tensor<8x8xf32> into tensor<64xf32>
+    return %collapsed, %2#1 : tensor<64xf32>, tensor<8x8xf16>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %2 = transform.iree.fuse_collapse_shape_with_forall %1 into %0 : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func @fuse_collapse_shape_with_multi_result_forall
+// CHECK:     %[[FORALL_RESULT:.+]]:2 = scf.forall {{.*}} -> (tensor<64xf32>, tensor<8x8xf16>) {
+// CHECK:       scf.forall.in_parallel {
+// CHECK-DAG:     tensor.parallel_insert_slice {{.*}} : tensor<16xf32> into tensor<64xf32>
+// CHECK-DAG:     tensor.parallel_insert_slice {{.*}} : tensor<2x8xf16> into tensor<8x8xf16>
+// CHECK:       }
+// CHECK:     } {mapping = [#gpu.thread]}
+// CHECK:     return %[[FORALL_RESULT]]#0, %[[FORALL_RESULT]]#1
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
index 75bf5e51d54c..8f78e5105982 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp
@@ -287,6 +287,217 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter,
   return success();
 }
 
+/// Return whether a parallel insert slice operation can be collapsed with
+/// the given reassociation indices. For a slice to be collapsable, each group
+/// of collapsed dimensions must be fully contiguous in the destination type.
+static LogicalResult
+collapsableSlicePrecondition(RewriterBase &rewriter,
+                             tensor::ParallelInsertSliceOp sliceOp,
+                             SmallVector<ReassociationIndices> reassociations) {
+  if (!areAllConstantIntValue(sliceOp.getMixedStrides(), 1)) {
+    return rewriter.notifyMatchFailure(sliceOp, "strides are not all 1");
+  }
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+  RankedTensorType fullTensorType = sliceOp.getDestType();
+  ArrayRef<int64_t> destShape = fullTensorType.getShape();
+  for (auto group : reassociations) {
+    bool isFullSlice = true;
+    for (auto idx : llvm::reverse(group)) {
+      std::optional<int64_t> constSize = getConstantIntValue(sizes[idx]);
+      // If the size is dynamic, then conservatively assume it is not full.
+      if (!constSize.has_value()) {
+        if (!isFullSlice) {
+          return rewriter.notifyMatchFailure(
+              sliceOp,
+              "parallel insert slice is not contiguous in the destination");
+        }
+        isFullSlice = false;
+        continue;
+      }
+      // Unit dimensions are always valid.
+      if (constSize.value() == 1) {
+        continue;
+      }
+      // If the size is not unit, then the slice must be full so far.
+      if (!isFullSlice) {
+        return rewriter.notifyMatchFailure(
+            sliceOp,
+            "parallel insert slice is not contiguous in the destination");
+      }
+      if (constSize.value() != destShape[idx]) {
+        isFullSlice = false;
+      }
+    }
+  }
+
+  RankedTensorType sliceType = sliceOp.getSourceType();
+  if (sliceOp.getMixedSizes().size() != sliceType.getRank()) {
+    return rewriter.notifyMatchFailure(
+        sliceOp, "parallel insert slice is rank reducing");
+  }
+  return success();
+}
+
+/// Collapse the given parallel insert slice op with the given
+/// `reassociations`. All strides are expected to be 1. This function assumes
+/// that the parallelInsertOp passes the collapsableSlicePrecondition.
+static tensor::ParallelInsertSliceOp
+collapseParallelInsertOp(RewriterBase &rewriter,
+                         tensor::ParallelInsertSliceOp parallelInsertOp,
+                         SmallVector<ReassociationIndices> reassociations) {
+  // Compute the collapsed offsets, sizes, and strides.
+  rewriter.setInsertionPoint(parallelInsertOp.getParallelCombiningParent());
+  Location loc = parallelInsertOp.getParallelCombiningParent()->getLoc();
+  int64_t resultIdx = parallelInsertOp.getTiedOpResult().getResultNumber();
+  auto forallOp = parallelInsertOp->getParentOfType<scf::ForallOp>();
+  Value loopInit = forallOp.getOutputs()[resultIdx];
+  SmallVector<OpFoldResult> mixedInitSizes =
+      tensor::getMixedSizes(rewriter, loc, loopInit);
+  auto prod = [&](ArrayRef<OpFoldResult> vals) -> OpFoldResult {
+    auto mulMap = AffineMap::get(
+        2, 0, {rewriter.getAffineDimExpr(0) * rewriter.getAffineDimExpr(1)});
+    OpFoldResult product = rewriter.getIndexAttr(1);
+    for (OpFoldResult val : vals) {
+      product = affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                      {product, val});
+    }
+    return product;
+  };
+  SmallVector<OpFoldResult> offsets = parallelInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = parallelInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> newSizes, newOffsets;
+  for (auto group : reassociations) {
+    if (group.size() == 1) {
+      newOffsets.push_back(offsets[group[0]]);
+      newSizes.push_back(sizes[group[0]]);
+      continue;
+    }
+    ArrayRef<OpFoldResult> basis(mixedInitSizes.begin() + group.front(),
+                                 mixedInitSizes.begin() + group.back() + 1);
+    ArrayRef<OpFoldResult> groupOffsets(offsets.begin() + group.front(),
+                                        offsets.begin() + group.back() + 1);
+    SmallVector<Value> offsetVals =
+        llvm::map_to_vector(groupOffsets, [&](OpFoldResult ofr) {
+          return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+        });
+    OpFoldResult collapsedOffset =
+        rewriter
+            .create<affine::AffineLinearizeIndexOp>(loc, offsetVals, basis)
+            .getResult();
+    ArrayRef<OpFoldResult> groupSizes(sizes.begin() + group.front(),
+                                      sizes.begin() + group.back() + 1);
+    OpFoldResult collapsedSize = prod(groupSizes);
+    newOffsets.push_back(collapsedOffset);
+    newSizes.push_back(collapsedSize);
+  }
+  SmallVector<OpFoldResult> newStrides(newSizes.size(),
+                                       rewriter.getIndexAttr(1));
+
+  // Collapse the slice source.
+  loc = parallelInsertOp.getParallelCombiningParent()->getLoc();
+  rewriter.setInsertionPoint(parallelInsertOp.getParallelCombiningParent());
+  auto newCollapse = rewriter.create<tensor::CollapseShapeOp>(
+      loc, parallelInsertOp.getSource(), reassociations);
+
+  // Collapse the parallel insert slice.
+  rewriter.setInsertionPoint(parallelInsertOp);
+  auto newInsertOp = rewriter.replaceOpWithNewOp<tensor::ParallelInsertSliceOp>(
+      parallelInsertOp, newCollapse, parallelInsertOp.getDest(), newOffsets,
+      newSizes, newStrides);
+  return newInsertOp;
+}
+
+FailureOr<scf::ForallOp>
+fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter,
+                                    scf::ForallOp forallOp,
+                                    tensor::CollapseShapeOp collapseOp) {
+  // Check that there is a single user of the collapsed result.
+  auto forallResult = cast<OpResult>(collapseOp.getSrc());
+  if (!forallResult.hasOneUse()) {
+    return rewriter.notifyMatchFailure(forallOp,
+                                       "forall result has multiple uses");
+  }
+
+  // Get the result's corresponding parallel_insert_slice op.
+  SmallVector<Operation *> parallelInsertOps = forallOp.getCombiningOps(
+      forallOp.getRegionIterArgs()[forallResult.getResultNumber()]);
+  if (parallelInsertOps.size() != 1) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Expected a single parallel_insert_slice");
+  }
+
+  auto parallelInsertOp =
+      dyn_cast<tensor::ParallelInsertSliceOp>(parallelInsertOps.front());
+  if (!parallelInsertOp) {
+    return rewriter.notifyMatchFailure(
+        forallOp, "Expected parallel_insert_slice combining op");
+  }
+
+  // Collapse the parallel insert slice op.
+  SmallVector<ReassociationIndices> reassociations =
+      collapseOp.getReassociationIndices();
+  if (failed(collapsableSlicePrecondition(rewriter, parallelInsertOp,
+                                          reassociations))) {
+    return failure();
+  }
+  tensor::ParallelInsertSliceOp newParallelInsertOp =
+      collapseParallelInsertOp(rewriter, parallelInsertOp, reassociations);
+
+  // At this point, the newParallelInsertOp still has the destination of the
+  // original parallel insert op, so the destination is the original expanded
+  // init block argument, and we can use it to get the sizes for the expand.
+  // The block argument will be corrected later, when the forall op is replaced.
+  Value initArg = newParallelInsertOp.getDest();
+  Value forallOutput = forallOp.getOutputs()[forallResult.getResultNumber()];
+  Location loc = forallOutput.getLoc();
+  rewriter.setInsertionPointAfterValue(forallOutput);
+  SmallVector<OpFoldResult> initSizes =
+      tensor::getMixedSizes(rewriter, loc, forallOutput);
+  loc = initArg.getLoc();
+  rewriter.setInsertionPointToStart(forallOp.getBody());
+  auto expandedInitArg = rewriter.create<tensor::ExpandShapeOp>(
+      loc, initArg.getType(), initArg, reassociations, initSizes);
+
+  // The new parallel insert slice is collapsed, so don't use the expanded init.
+  // Also don't replace the expand shape src with its own result.
+  rewriter.replaceUsesWithIf(
+      initArg, expandedInitArg.getResult(), [&](OpOperand &operand) {
+        return operand != expandedInitArg.getSrcMutable() &&
+               operand != newParallelInsertOp.getDestMutable();
+      });
+
+  // Now create a new scf::Forall with a collapsed loop init.
+  loc = forallOp->getLoc();
+  rewriter.setInsertionPoint(forallOp);
+  SmallVector<Value> newForallOutputs(forallOp.getOutputs());
+  Value collapsedLoopInit = rewriter.create<tensor::CollapseShapeOp>(
+      loc, newForallOutputs[forallResult.getResultNumber()], reassociations);
+  newForallOutputs[forallResult.getResultNumber()] = collapsedLoopInit;
+
+  scf::ForallOp newForallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newForallOutputs, forallOp.getMappingAttr());
+
+  SmallVector<Value> argReplacements(newForallOp.getInductionVars());
+  argReplacements.append(newForallOp.getRegionIterArgs().begin(),
+                         newForallOp.getRegionIterArgs().end());
+  newForallOp.getTerminator()->erase();
+  rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(),
+                       argReplacements);
+
+  // Replace the uses of the old scf.forall with the new scf.forall.
+  rewriter.replaceOp(collapseOp,
+                     newForallOp.getResult(forallResult.getResultNumber()));
+  for (int idx = 0; idx < forallOp->getNumResults(); ++idx) {
+    if (idx == forallResult.getResultNumber()) {
+      continue;
+    }
+    forallOp->getResult(idx).replaceAllUsesWith(newForallOp->getResult(idx));
+  }
+  rewriter.eraseOp(forallOp);
+  return newForallOp;
+}
+
 //===----------------------------------------------------------------------===//
 // MultiMmaOp Lowering
 //===----------------------------------------------------------------------===//
diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
index 515ed84e70c3..444762b09ed7 100644
--- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
+++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h
@@ -56,6 +56,49 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter,
                                      scf::ForallOp consumer,
                                      SmallVector<Operation *> consumerChain);
 
+/// Function to fuse a collapse shape op into a forall op producer. This
+/// rewrite effectively bubbles the collapse_shape op up through the forall
+/// output operand, and the block argument inside the forall becomes expanded
+/// with the reassociation indices of the collapse. The parallel_insert_slice
+/// for the collapsed init will be collapsed, and an expand_shape on the loop
+/// init block argument will be added to ensure that types match for users
+/// other than the parallel_insert_slice. The following example illustrates a
+/// simple case of this transformation:
+///
+/// ```
+/// %forall = scf.forall ... shared_outs(%arg = %dest) -> tensor<4x4xf32> {
+///   %user = "some_user" %arg
+///   ...
+///   scf.forall.in_parallel {
+///     tensor.parallel_insert_slice %val into %arg ...
+///       tensor<1x4xf32> into tensor<4x4xf32>
+///   }
+/// }
+/// %collapse = tensor.collapse_shape %forall ...
+///   tensor<4x4xf32> into tensor<16xf32>
+/// ```
+/// After the transformation this would become:
+/// ```
+/// %collapse = tensor.collapse_shape %dest ...
+///   tensor<4x4xf32> into tensor<16xf32>
+/// %forall = scf.forall ... shared_outs(%arg = %collapse) -> tensor<16xf32> {
+///   %expanded_arg = tensor.expand_shape %arg ...
+///     tensor<16xf32> into tensor<4x4xf32>
+///   %user = "some_user" %expanded_arg
+///   ...
+///   %collapsed_val = tensor.collapse_shape %val ...
+///     tensor<1x4xf32> into tensor<4xf32>
+///   scf.forall.in_parallel {
+///     tensor.parallel_insert_slice %collapsed_val into %arg ...
+///       tensor<4xf32> into tensor<16xf32>
+///   }
+/// }
+/// ```
+FailureOr<scf::ForallOp>
+fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter,
+                                    scf::ForallOp forallOp,
+                                    tensor::CollapseShapeOp collapseOp);
+
 // Helper to convert a contraction-like linalg op to an iree_gpu.multi_mma.
 FailureOr<IREE::GPU::MultiMmaOp>
 convertContractionToMultiMma(RewriterBase &rewriter, linalg::LinalgOp linalgOp,
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
index a716c6b7579c..2354e8c1ad8e 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir
@@ -1065,7 +1065,7 @@ hal.executable public @main {
 // CHECK-DAG:     %[[B2:.+]] = hal.interface.binding.subspan layout({{.+}}) binding(2)
 // CHECK-DAG:     memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
 // CHECK-DAG:     memref.alloc() : memref<64x36xf16, #gpu.address_space<workgroup>>
-// CHECK-DAG:     memref.alloc() : memref<4x16x4x16xf32, #gpu.address_space<workgroup>>
+// CHECK-DAG:     memref.alloc() : memref<64x66xf32, #gpu.address_space<workgroup>>
 // CHECK:         scf.forall ({{.*}}) in (32, 160) {
 // CHECK:           %[[LOOP:.+]] = scf.for %[[IV:.+]] = %c0 to %c80 step %c2 {{.*}} -> (vector<2x2x4x1xf32>)
 // CHECK:             gpu.barrier
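For reference, a minimal transform script that drives the new fusion op is sketched below. It mirrors transform_fuse_collapse_shape_with_forall.mlir above; the structured.match-based handle lookup is just one way to obtain the producer and consumer handles and is not required by the op itself.

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%root: !transform.any_op {transform.readonly}) {
    // Grab the scf.forall producer and the trailing tensor.collapse_shape consumer.
    %forall = transform.structured.match ops{["scf.forall"]} in %root : (!transform.any_op) -> !transform.any_op
    %collapse = transform.structured.match ops{["tensor.collapse_shape"]} in %root : (!transform.any_op) -> !transform.any_op
    // Fuse the collapse_shape into its forall producer; both input handles are consumed.
    %fused = transform.iree.fuse_collapse_shape_with_forall %collapse into %forall
      : (!transform.any_op, !transform.any_op) -> !transform.any_op
    transform.yield
  }
}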