From fc8dea94129a2d8c194b2ff5c3524518fce9c0cb Mon Sep 17 00:00:00 2001
From: Max191 <44243577+Max191@users.noreply.github.com>
Date: Tue, 28 Jan 2025 14:47:02 -0500
Subject: [PATCH] [GPU] Add pattern to fuse tensor.extract_slice into forall
 producer (#19296)

This PR adds a pattern to fuse a consumer tensor.extract_slice into a
producer scf.forall op. The transform is added to FuseAndHoistParallelLoops,
where it helps to fuse tensor.unpack ops with extract_slice semantics into
producer loops. This is needed when targeting MFMA intrinsics for unaligned
shapes, and also in generating code for unset encoding ops on GPU.

This is a follow-up to https://github.com/iree-org/iree/pull/19295, which has
the complementary pattern for collapse_shape. The PR also adds a transform op
to keep the long lit tests separate from the FuseAndHoistParallelLoops tests.

---------

Signed-off-by: Max Dawkins
Signed-off-by: Max Dawkins
Co-authored-by: Max Dawkins
Signed-off-by: Hyunsung Lee
---
 .../GPU/GPUFuseAndHoistParallelLoops.cpp      |  22 ++
 .../GPU/test/gpu_fuse_and_hoist_forall.mlir   |  27 ++
 .../TransformExtensions/IREEGPUExtensions.cpp |  48 +++
 .../IREEGPUExtensionsOps.td                   |  34 ++
 .../GPU/TransformExtensions/test/BUILD.bazel  |   1 +
 .../TransformExtensions/test/CMakeLists.txt   |   1 +
 ...nsform_fuse_extract_slice_with_forall.mlir | 310 ++++++++++++++++++
 .../Dialect/GPU/Transforms/Transforms.cpp     | 256 ++++++++++++++-
 .../Dialect/GPU/Transforms/Transforms.h       |  45 +++
 9 files changed, 741 insertions(+), 3 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
index 5e47104d9b58d..769866f498078 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp
@@ -346,6 +346,27 @@ struct FuseCollapseShapeConsumers final
   }
 };
 
+struct FuseExtractSliceConsumers final
+    : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    // Find the scf::ForallOp producer, and get the corresponding
+    // tensor::ParallelInsertSliceOp.
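+    // The rewrite is shared with the transform dialect op
+    // (transform.iree.fuse_extract_slice_into_forall) through
+    // GPU::fuseExtractSliceIntoProducerForall in Transforms.cpp.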
+ auto forallOp = extractSliceOp.getSource().getDefiningOp(); + if (!forallOp) { + return rewriter.notifyMatchFailure(extractSliceOp, + "No forall op producer"); + } + + if (failed(fuseExtractSliceIntoProducerForall(rewriter, forallOp, + extractSliceOp))) { + return failure(); + } + return success(); + } +}; + void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { MLIRContext *context = &getContext(); @@ -391,6 +412,7 @@ void GPUFuseAndHoistParallelLoopsPass::runOnOperation() { patterns.add(context); patterns.add(context); patterns.add(context); + patterns.add(context); populateSwapExtractWithExpandPattern(patterns); tensor::populateFoldTensorEmptyPatterns(patterns); scf::ForallOp::getCanonicalizationPatterns(patterns, context); diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir index a019876cf3eae..76a30902610fa 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_fuse_and_hoist_forall.mlir @@ -600,3 +600,30 @@ func.func @no_fuse_collapse_shape_rank_reduced(%arg0: tensor<8x8xf32>) -> tensor // CHECK: } {mapping = [#gpu.thread]} // CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]] // CHECK: return %[[COLLAPSE]] + +// ----- + +#map = affine_map<(d0) -> (d0 * 2)> +func.func @no_fuse_extract_slice_rank_reduced(%arg0: tensor<4x8xf32>, %size1: index) -> tensor { + %0 = tensor.empty() : tensor<4x8xf32> + %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<4x8xf32>) { + %2 = affine.apply #map(%arg2) + %extracted_slice_0 = tensor.extract_slice %arg0[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> + %extracted_slice_1 = tensor.extract_slice %arg3[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> + %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[0, %2] [1, 2] [1, 1] : tensor<2xf32> into tensor<4x8xf32> + } + } {mapping = [#gpu.thread]} + %extracted_slice = tensor.extract_slice %1[0, 0] [1, %size1] [1, 1] : tensor<4x8xf32> to tensor + return %extracted_slice : tensor +} + +// CHECK-LABEL: func @no_fuse_extract_slice_rank_reduced +// CHECK: %[[FORALL_RESULT:.+]] = scf.forall {{.*}} -> (tensor<4x8xf32>) { +// CHECK: scf.forall.in_parallel { +// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2xf32> into tensor<4x8xf32> +// CHECK: } +// CHECK: } {mapping = [#gpu.thread]} +// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[FORALL_RESULT]] +// CHECK: return %[[EXTRACT]] diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp index ec2a53dd2c267..4059b34c56574 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensions.cpp @@ -266,6 +266,54 @@ void transform_dialect::FuseCollapseShapeIntoForallOp::getEffects( transform::modifiesPayload(effects); } +//===---------------------------------------------------------------------===// +// FuseExtractSliceIntoForallOp +//===---------------------------------------------------------------------===// + +DiagnosedSilenceableFailure 
+transform_dialect::FuseExtractSliceIntoForallOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { + auto producers = state.getPayloadOps(getProducer()); + auto consumers = state.getPayloadOps(getConsumer()); + + int64_t numProducers = llvm::range_size(producers); + int64_t numConsumers = llvm::range_size(consumers); + if (numProducers != 1 || numConsumers != 1) { + return mlir::emitDefiniteFailure(state.getTopLevel(), + "More than one producer or consumer"); + } + + auto producer = dyn_cast(*producers.begin()); + if (!producer) { + return mlir::emitDefiniteFailure(state.getTopLevel(), + "Non-forall producer"); + } + auto consumer = dyn_cast(*consumers.begin()); + if (!consumer) { + return mlir::emitDefiniteFailure(state.getTopLevel(), + "Non-extract_slice consumer"); + } + + FailureOr fusedForallOp = + GPU::fuseExtractSliceIntoProducerForall(rewriter, producer, consumer); + if (failed(fusedForallOp)) { + return mlir::emitSilenceableFailure(*this, + "failed to fuse extract_slice op"); + } + + results.set(getOperation()->getOpResult(0), {fusedForallOp.value()}); + return DiagnosedSilenceableFailure::success(); +} + +void transform_dialect::FuseExtractSliceIntoForallOp::getEffects( + SmallVectorImpl &effects) { + transform::consumesHandle(getProducerMutable(), effects); + transform::consumesHandle(getConsumerMutable(), effects); + transform::producesHandle(getOperation()->getOpResults(), effects); + transform::modifiesPayload(effects); +} + } // namespace mlir::iree_compiler::IREE void mlir::iree_compiler::registerTransformDialectIREEGPUExtension( diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td index 2a1a7309cbc5e..38b4d2081d578 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/IREEGPUExtensionsOps.td @@ -262,4 +262,38 @@ def FuseCollapseShapeIntoForallOp : Op, + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + Fuses a consumer tensor.extract_slice op into a producer scf.forall op. + This transform is supported if the extract_slice op has all zero offsets, + and if all the offsets, sizes, and strides dominate the scf.forall op. + After the transformation, the forall loop output argument corresponding + to the sliced result will be replaced with a slice of it with the same + offsets, sizes, and strides as the original extract_slice. The source of + the corresponding tensor.parallel_insert_slice of the scf.forall will also + become a slice of the original parallel insert source, clamped to fit within + the new sliced result tensor. + + #### Return modes + Emits a definite failure if either the producer is not an scf.forall op or + if the consumer is not a tensor.extract_slice op. 
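+
+    #### Example
+
+    A minimal usage sketch (mirroring the lit tests added in this PR; the
+    handle names are illustrative):
+
+    ```mlir
+    %loop = transform.structured.match ops{["scf.forall"]} in %root
+      : (!transform.any_op) -> !transform.any_op
+    %slice = transform.get_consumers_of_result %loop[0]
+      : (!transform.any_op) -> !transform.any_op
+    %fused = transform.iree.fuse_extract_slice_into_forall %slice into %loop
+      : (!transform.any_op, !transform.any_op) -> !transform.any_op
+    ```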
+ }]; + + let arguments = ( + ins TransformHandleTypeInterface:$producer, + TransformHandleTypeInterface:$consumer + ); + let results = (outs TransformHandleTypeInterface:$result); + + let assemblyFormat = [{ + $consumer `into` $producer attr-dict + `:` functional-type(operands, results) + }]; + let cppNamespace = "mlir::iree_compiler::IREE::transform_dialect"; +} + #endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_TRANSFORMEXTENSIONS_IREEGPUEXTENSIONS diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel index 428211b3ea01b..c137bef9afbe2 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/BUILD.bazel @@ -25,6 +25,7 @@ iree_lit_test_suite( "lower_multi_mma.mlir", "lower_vector_barrier.mlir", "transform_fuse_collapse_shape_with_forall.mlir", + "transform_fuse_extract_slice_with_forall.mlir", "transform_fuse_forall.mlir", "transform_lower_barrier_region.mlir", "vectorize_iree_gpu_ops.mlir", diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt index 344da8cf34d9e..abeff344d337a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/CMakeLists.txt @@ -21,6 +21,7 @@ iree_lit_test_suite( "lower_multi_mma.mlir" "lower_vector_barrier.mlir" "transform_fuse_collapse_shape_with_forall.mlir" + "transform_fuse_extract_slice_with_forall.mlir" "transform_fuse_forall.mlir" "transform_lower_barrier_region.mlir" "unroll_multi_mma.mlir" diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir new file mode 100644 index 0000000000000..c6a75421e83be --- /dev/null +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/transform_fuse_extract_slice_with_forall.mlir @@ -0,0 +1,310 @@ +// RUN: iree-opt %s -iree-transform-dialect-interpreter --verify-diagnostics -transform-dialect-drop-schedule -canonicalize -cse --split-input-file | FileCheck %s + +#map = affine_map<(d0) -> (d0 * 2)> +module { + func.func @fuse_extract_slice_into_forall(%arg0: tensor<8xf32>, %arg1: index) -> tensor { + %0 = tensor.empty() : tensor<8xf32> + %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<8xf32>) { + %2 = affine.apply #map(%arg2) + %extracted_slice_0 = tensor.extract_slice %arg0[%2] [2] [1] : tensor<8xf32> to tensor<2xf32> + %extracted_slice_1 = tensor.extract_slice %arg3[%2] [2] [1] : tensor<8xf32> to tensor<2xf32> + %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [2] [1] : tensor<2xf32> into tensor<8xf32> + } + } {mapping = [#gpu.thread]} + %extracted_slice = tensor.extract_slice %1[0] [%arg1] [1] : tensor<8xf32> to tensor + return %extracted_slice : tensor + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %producer = transform.structured.match 
ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer = transform.get_consumers_of_result %producer[0] : (!transform.any_op) -> !transform.any_op + %2 = transform.iree.fuse_extract_slice_into_forall %consumer into %producer + : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * -2 + s0, 0)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (2, d0)> + +// CHECK-LABEL: func @fuse_extract_slice_into_forall +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<8xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: index + +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8xf32> +// CHECK-DAG: %[[SLICED_OUT:.+]] = tensor.extract_slice %[[EMPTY]][0] [%[[ARG1]]] [1] : tensor<8xf32> to tensor +// CHECK: %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX:.+]]) in (4) shared_outs(%[[SLICED_BBARG:.+]] = %[[SLICED_OUT]]) -> (tensor) { + +// CHECK-DAG: %[[SLICE_IDX:.+]] = affine.apply #[[$MAP]](%[[IDX]]) +// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[SLICE_IDX]]] [2] [1] : tensor<8xf32> to tensor<2xf32> +// CHECK-DAG: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EMPTY]][%[[SLICE_IDX]]] [2] [1] : tensor<8xf32> to tensor<2xf32> +// CHECK-DAG: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2xf32>) outs(%[[OUT_SLICE]] : tensor<2xf32>) -> tensor<2xf32> + +// CHECK-DAG: %[[SIZE_CLAMPED_LOW:.+]] = affine.max #[[$MAP1]](%[[IDX]])[%[[ARG1]]] +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH:.+]] = affine.min #[[$MAP2]](%[[SIZE_CLAMPED_LOW]]) + +// CHECK: %[[SLICED_COPY:.+]] = tensor.extract_slice %[[COPY]][0] [%[[SIZE_CLAMPED_HIGH]]] [1] : tensor<2xf32> to tensor +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[SLICED_COPY]] into %[[SLICED_BBARG]][%[[SLICE_IDX]]] [%[[SIZE_CLAMPED_HIGH]]] [1] : tensor into tensor +// CHECK: } +// CHECK: } {mapping = [#gpu.thread]} +// CHECK: return %[[FORALL_RESULT]] + +// ----- + +module { + module { + func.func @fuse_dynamic_extract_slice_into_forall(%arg0: tensor, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index) -> tensor { + %0 = tensor.empty(%arg1, %arg2) : tensor + %1 = scf.forall (%arg7, %arg8) = (0, 0) to (%arg1, %arg2) step (%arg5, %arg6) shared_outs(%arg9 = %0) -> (tensor) { + %extracted_slice_0 = tensor.extract_slice %arg0[%arg7, %arg8] [%arg5, %arg6] [1, 1] : tensor to tensor + %extracted_slice_1 = tensor.extract_slice %arg9[%arg7, %arg8] [%arg5, %arg6] [1, 1] : tensor to tensor + %2 = linalg.copy ins(%extracted_slice_0 : tensor) outs(%extracted_slice_1 : tensor) -> tensor + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg9[%arg7, %arg8] [%arg5, %arg6] [1, 1] : tensor into tensor + } + } {mapping = [#gpu.thread, #gpu.thread]} + %extracted_slice = tensor.extract_slice %1[0, 0] [%arg3, %arg4] [1, 1] : tensor to tensor + return %extracted_slice : tensor + } + } + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_consumers_of_result %0[0] : (!transform.any_op) -> !transform.any_op + %2 = transform.iree.fuse_extract_slice_into_forall %1 into %0 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } + } +} + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0)[s0] -> (-d0 
+ s0, 0)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (s0, d0)> + +// CHECK-LABEL: func @fuse_dynamic_extract_slice_into_forall +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor +// CHECK-SAME: %[[SIZE0:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[SIZE1:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[EXTRACT_SIZE0:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[EXTRACT_SIZE1:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[STEP0:[A-Za-z0-9]+]]: index +// CHECK-SAME: %[[STEP1:[A-Za-z0-9]+]]: index + +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[SIZE0]], %[[SIZE1]]) : tensor +// CHECK-DAG: %[[SLICED_OUT:.+]] = tensor.extract_slice %[[EMPTY]][0, 0] [%[[EXTRACT_SIZE0]], %[[EXTRACT_SIZE1]]] [1, 1] : tensor to tensor +// CHECK: %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX0:.+]], %[[IDX1:.+]]) = (0, 0) to (%[[SIZE0]], %[[SIZE1]]) step (%[[STEP0]], %[[STEP1]]) +// CHECK-SAME: shared_outs(%[[SLICED_BBARG:.+]] = %[[SLICED_OUT]]) -> (tensor) { + +// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]]{{.*}}[%[[IDX0]], %[[IDX1]]] [%[[STEP0]], %[[STEP1]]] [1, 1] : tensor to tensor +// CHECK-DAG: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EMPTY]]{{.*}}[%[[IDX0]], %[[IDX1]]] [%[[STEP0]], %[[STEP1]]] [1, 1] : tensor to tensor +// CHECK-DAG: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor) outs(%[[OUT_SLICE]] : tensor) -> tensor + +// CHECK-DAG: %[[SIZE_CLAMPED_LOW0:.+]] = affine.max #[[$MAP]](%[[IDX0]])[%[[EXTRACT_SIZE0]]] +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH0:.+]] = affine.min #[[$MAP1]](%[[SIZE_CLAMPED_LOW0]])[%[[STEP0]]] +// CHECK-DAG: %[[SIZE_CLAMPED_LOW1:.+]] = affine.max #[[$MAP]](%[[IDX1]])[%[[EXTRACT_SIZE1]]] +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH1:.+]] = affine.min #[[$MAP1]](%[[SIZE_CLAMPED_LOW1]])[%[[STEP1]]] + +// CHECK: %[[SLICED_COPY:.+]] = tensor.extract_slice %[[COPY]][0, 0] [%[[SIZE_CLAMPED_HIGH0]], %[[SIZE_CLAMPED_HIGH1]]] [1, 1] : tensor to tensor +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[SLICED_COPY]] into %[[SLICED_BBARG]] +// CHECK-SAME: [%[[IDX0]], %[[IDX1]]] [%[[SIZE_CLAMPED_HIGH0]], %[[SIZE_CLAMPED_HIGH1]]] [1, 1] : tensor into tensor +// CHECK: } +// CHECK: } {mapping = [#gpu.thread, #gpu.thread]} +// CHECK: return %[[FORALL_RESULT]] + +// ----- + +module { + func.func @fuse_rank_reduced_extract_slice_into_forall(%arg0: tensor<4x8xf32>, %arg1: index) -> tensor { + %0 = tensor.empty() : tensor<4x8xf32> + %1 = scf.forall (%arg2, %arg3) = (0, 0) to (4, 8) step (2, 2) shared_outs(%arg4 = %0) -> (tensor<4x8xf32>) { + %extracted_slice_0 = tensor.extract_slice %arg0[%arg2, %arg3] [2, 2] [1, 1] : tensor<4x8xf32> to tensor<2x2xf32> + %extracted_slice_1 = tensor.extract_slice %arg4[%arg2, %arg3] [2, 2] [1, 1] : tensor<4x8xf32> to tensor<2x2xf32> + %2 = linalg.copy ins(%extracted_slice_0 : tensor<2x2xf32>) outs(%extracted_slice_1 : tensor<2x2xf32>) -> tensor<2x2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %2 into %arg4[%arg2, %arg3] [2, 2] [1, 1] : tensor<2x2xf32> into tensor<4x8xf32> + } + } {mapping = [#gpu.thread, #gpu.thread]} + %extracted_slice = tensor.extract_slice %1[0, 0] [1, %arg1] [1, 1] : tensor<4x8xf32> to tensor + return %extracted_slice : tensor + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_consumers_of_result %0[0] : (!transform.any_op) -> !transform.any_op + %2 = 
transform.iree.fuse_extract_slice_into_forall %1 into %0 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (-d0 + 1, 0)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (2, d0)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 0)> + +// CHECK-LABEL: func @fuse_rank_reduced_extract_slice_into_forall +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<4x8xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: index + +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4x8xf32> +// CHECK-DAG: %[[SLICED_OUT:.+]] = tensor.extract_slice %[[EMPTY]][0, 0] [1, %[[ARG1]]] [1, 1] : tensor<4x8xf32> to tensor<1x?xf32> +// CHECK: %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX0:.+]], %[[IDX1:.+]]) = (0, 0) to (4, 8) step (2, 2) +// CHECK-SAME: shared_outs(%[[SLICED_BBARG:.+]] = %[[SLICED_OUT]]) -> (tensor<1x?xf32>) { + +// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]]{{.*}}[%[[IDX0]], %[[IDX1]]] [2, 2] [1, 1] : tensor<4x8xf32> to tensor<2x2xf32> +// CHECK-DAG: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EMPTY]]{{.*}}[%[[IDX0]], %[[IDX1]]] [2, 2] [1, 1] : tensor<4x8xf32> to tensor<2x2xf32> +// CHECK-DAG: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2x2xf32>) outs(%[[OUT_SLICE]] : tensor<2x2xf32>) -> tensor<2x2xf32> + +// CHECK-DAG: %[[SIZE_CLAMPED_LOW0:.+]] = affine.max #[[$MAP]](%[[IDX0]]) +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH0:.+]] = affine.min #[[$MAP1]](%[[SIZE_CLAMPED_LOW0]]) +// CHECK-DAG: %[[SIZE_CLAMPED_LOW1:.+]] = affine.max #[[$MAP2]](%[[IDX1]])[%[[ARG1]]] +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH1:.+]] = affine.min #[[$MAP1]](%[[SIZE_CLAMPED_LOW1]]) + +// CHECK: %[[SLICED_COPY:.+]] = tensor.extract_slice %[[COPY]] +// CHECK-SAME: [0, 0] [%[[SIZE_CLAMPED_HIGH0]], %[[SIZE_CLAMPED_HIGH1]]] [1, 1] : tensor<2x2xf32> to tensor +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[SLICED_COPY]] into %[[SLICED_BBARG]] +// CHECK-SAME: [%[[IDX0]], %[[IDX1]]] [%[[SIZE_CLAMPED_HIGH0]], %[[SIZE_CLAMPED_HIGH1]]] [1, 1] : tensor into tensor<1x?xf32> +// CHECK: } +// CHECK: } {mapping = [#gpu.thread, #gpu.thread]} +// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[FORALL_RESULT]] {{\[}}[0, 1]] : tensor<1x?xf32> into tensor +// CHECK: return %[[COLLAPSE]] + +// ----- + +#map = affine_map<(d0) -> (d0 * 2)> +module { + func.func @fuse_extract_slice_into_rank_reduced_forall_slices(%arg0: tensor<4x8xf32>, %size1: index) -> tensor<4x?xf32> { + %0 = tensor.empty() : tensor<4x8xf32> + %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<4x8xf32>) { + %2 = affine.apply #map(%arg2) + %extracted_slice_0 = tensor.extract_slice %arg0[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> + %extracted_slice_1 = tensor.extract_slice %arg3[0, %2] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> + %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[0, %2] [1, 2] [1, 1] : tensor<2xf32> into tensor<4x8xf32> + } + } {mapping = [#gpu.thread]} + %extracted_slice = tensor.extract_slice %1[0, 0] [4, %size1] [1, 1] : tensor<4x8xf32> to tensor<4x?xf32> + return %extracted_slice : tensor<4x?xf32> + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %producer = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> 
!transform.any_op + %consumer = transform.get_consumers_of_result %producer[0] : (!transform.any_op) -> !transform.any_op + %2 = transform.iree.fuse_extract_slice_into_forall %consumer into %producer + : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 * 2)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (d0 * -2 + s0, 0)> +// CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0) -> (2, d0)> + +// CHECK-LABEL: func @fuse_extract_slice_into_rank_reduced_forall_slices +// CHECK-SAME: %[[ARG0:[A-Za-z0-9]+]]: tensor<4x8xf32> +// CHECK-SAME: %[[ARG1:[A-Za-z0-9]+]]: index + +// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<4x8xf32> +// CHECK-DAG: %[[SLICED_OUT:.+]] = tensor.extract_slice %[[EMPTY]][0, 0] [4, %[[ARG1]]] [1, 1] : tensor<4x8xf32> to tensor<4x?xf32> +// CHECK: %[[FORALL_RESULT:.+]] = scf.forall (%[[IDX:.+]]) in (4) +// CHECK-SAME: shared_outs(%[[SLICED_BBARG:.+]] = %[[SLICED_OUT]]) -> (tensor<4x?xf32>) { + +// CHECK-DAG: %[[SLICE_IDX:.+]] = affine.apply #[[$MAP]](%[[IDX]]) +// CHECK-DAG: %[[IN_SLICE:.+]] = tensor.extract_slice %[[ARG0]]{{.*}}[0, %[[SLICE_IDX]]] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> +// CHECK-DAG: %[[OUT_SLICE:.+]] = tensor.extract_slice %[[EMPTY]]{{.*}}[0, %[[SLICE_IDX]]] [1, 2] [1, 1] : tensor<4x8xf32> to tensor<2xf32> +// CHECK-DAG: %[[COPY:.+]] = linalg.copy ins(%[[IN_SLICE]] : tensor<2xf32>) outs(%[[OUT_SLICE]] : tensor<2xf32>) -> tensor<2xf32> + +// CHECK-DAG: %[[SIZE_CLAMPED_LOW:.+]] = affine.max #[[$MAP1]](%[[IDX]])[%[[ARG1]]] +// CHECK-DAG: %[[SIZE_CLAMPED_HIGH:.+]] = affine.min #[[$MAP2]](%[[SIZE_CLAMPED_LOW]]) + +// CHECK: %[[SLICED_COPY:.+]] = tensor.extract_slice %[[COPY]][0] [%[[SIZE_CLAMPED_HIGH]]] [1] : tensor<2xf32> to tensor +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[SLICED_COPY]] into %[[SLICED_BBARG]] +// CHECK-SAME: [0, %[[SLICE_IDX]]] [1, %[[SIZE_CLAMPED_HIGH]]] [1, 1] : tensor into tensor<4x?xf32> +// CHECK: } +// CHECK: } {mapping = [#gpu.thread]} +// CHECK: return %[[FORALL_RESULT]] + +// ----- + +#map = affine_map<(d0) -> (d0 * 2)> +module { + func.func @fuse_extract_slice_into_multi_result_forall(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>, %arg2: index) -> (tensor, tensor<8xf32>) { + %0 = tensor.empty() : tensor<8xf32> + %1 = tensor.empty() : tensor<8xf32> + %2:2 = scf.forall (%arg3) in (4) shared_outs(%arg4 = %0, %arg5 = %1) -> (tensor<8xf32>, tensor<8xf32>) { + %3 = affine.apply #map(%arg3) + %extracted_slice_0 = tensor.extract_slice %arg0[%3] [2] [1] : tensor<8xf32> to tensor<2xf32> + %extracted_slice_1 = tensor.extract_slice %arg4[%3] [2] [1] : tensor<8xf32> to tensor<2xf32> + %extracted_slice_2 = tensor.extract_slice %arg1[%3] [2] [1] : tensor<8xf32> to tensor<2xf32> + %extracted_slice_3 = tensor.extract_slice %arg5[%3] [2] [1] : tensor<8xf32> to tensor<2xf32> + %4 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32> + %5 = linalg.copy ins(%extracted_slice_2 : tensor<2xf32>) outs(%extracted_slice_3 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %4 into %arg4[%3] [2] [1] : tensor<2xf32> into tensor<8xf32> + tensor.parallel_insert_slice %5 into %arg5[%3] [2] [1] : tensor<2xf32> into tensor<8xf32> + } + } {mapping = [#gpu.thread]} + %extracted_slice = tensor.extract_slice %2#0[0] [%arg2] [1] : tensor<8xf32> to tensor + return %extracted_slice, %2#1 : tensor, tensor<8xf32> + } +} +module attributes 
{transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_consumers_of_result %0[0] : (!transform.any_op) -> !transform.any_op + %2 = transform.iree.fuse_extract_slice_into_forall %1 into %0 : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} + +// CHECK-LABEL: func @fuse_extract_slice_into_multi_result_forall + +// CHECK: %[[FORALL_RESULT:.+]]:2 = scf.forall {{.*}} -> (tensor, tensor<8xf32>) { +// CHECK: scf.forall.in_parallel { +// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor into tensor +// CHECK-DAG: tensor.parallel_insert_slice {{.*}} : tensor<2xf32> into tensor<8xf32> +// CHECK: } +// CHECK: } {mapping = [#gpu.thread]} +// CHECK: return %[[FORALL_RESULT]]#0, %[[FORALL_RESULT]]#1 + +// ----- + +#map = affine_map<(d0) -> (d0 * 2)> +module { + func.func @no_fuse_extract_slice_with_offset(%arg0: tensor<8xf32>, %arg1: index) -> tensor { + %0 = tensor.empty() : tensor<8xf32> + %1 = scf.forall (%arg2) in (4) shared_outs(%arg3 = %0) -> (tensor<8xf32>) { + %2 = affine.apply #map(%arg2) + %extracted_slice_0 = tensor.extract_slice %arg0[%2] [2] [1] : tensor<8xf32> to tensor<2xf32> + %extracted_slice_1 = tensor.extract_slice %arg3[%2] [2] [1] : tensor<8xf32> to tensor<2xf32> + %3 = linalg.copy ins(%extracted_slice_0 : tensor<2xf32>) outs(%extracted_slice_1 : tensor<2xf32>) -> tensor<2xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %3 into %arg3[%2] [2] [1] : tensor<2xf32> into tensor<8xf32> + } + } {mapping = [#gpu.thread]} + %extracted_slice = tensor.extract_slice %1[2] [%arg1] [1] : tensor<8xf32> to tensor + return %extracted_slice : tensor + } +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %producer = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %consumer = transform.get_consumers_of_result %producer[0] : (!transform.any_op) -> !transform.any_op + // expected-error@+1 {{failed to fuse extract_slice op}} + %2 = transform.iree.fuse_extract_slice_into_forall %consumer into %producer + : (!transform.any_op, !transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index e30a14d8f9739..cc2c7e5accabc 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -12,12 +12,10 @@ #include "iree/compiler/Codegen/Utils/GPUUtils.h" #include "iree/compiler/Codegen/Utils/MarkerUtils.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/LogicalResult.h" -#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -518,6 +516,258 @@ fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter, return newForallOp; } +/// Return whether the `parallelInsertOp` can be clamped along the sliced +/// dimensions of `extractSliceOp`. 
The dimensions of the extractSliceOp source +/// are expected to match the dimensions of the parallelInsertOp destination. +/// This function checks that the parallelInsertOp is not rank reducing along +/// any of the sliced dimensions of the extractSliceOp. +static LogicalResult canClampParallelInsertSlice( + RewriterBase &rewriter, tensor::ParallelInsertSliceOp parallelInsertOp, + tensor::ExtractSliceOp extractSliceOp, + llvm::SmallDenseSet insertRankReductionMask) { + // Find the dimensions that are sliced by the extractSliceOp + llvm::SmallDenseSet slicedDims; + ArrayRef sliceStaticSizes = extractSliceOp.getStaticSizes(); + ArrayRef sliceSourceSizes = + extractSliceOp.getSourceType().getShape(); + for (int dim = 0; dim < sliceStaticSizes.size(); ++dim) { + if (ShapedType::isDynamic(sliceStaticSizes[dim]) || + sliceStaticSizes[dim] != sliceSourceSizes[dim]) { + slicedDims.insert(dim); + } + } + for (int dim = 0; dim < parallelInsertOp.getDestType().getRank(); ++dim) { + if (insertRankReductionMask.contains(dim) && slicedDims.contains(dim)) { + return rewriter.notifyMatchFailure( + parallelInsertOp, "parallel insert reduces sliced dimensions"); + } + } + return success(); +} + +/// Clamps the source of a parallel_insert_slice op to fit within the +/// `upperBoundSizes`. This function computes the upper bound sizes, and creates +/// an extract slice op on the parallel insert source, which is then used in a +/// new parallel insert slice to replace the old one. This function assumes that +/// the parallel insert op passes `canClampParallelInsertSlice` precondition. +static FailureOr +clampParallelInsertSliceOp(RewriterBase &rewriter, + tensor::ParallelInsertSliceOp parallelInsertOp, + SmallVector upperBoundSizes) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(parallelInsertOp.getParallelCombiningParent()); + Location loc = parallelInsertOp.getParallelCombiningParent()->getLoc(); + + // Clamp the parallel_insert_slice sizes to fit within the full result tensor. + SmallVector offsets = parallelInsertOp.getMixedOffsets(); + SmallVector sizes = parallelInsertOp.getMixedSizes(); + SmallVector clampedSizes; + for (auto [offset, size, ub] : + llvm::zip_equal(offsets, sizes, upperBoundSizes)) { + AffineExpr d0, d1, d2; + MLIRContext *ctx = rewriter.getContext(); + bindDims(ctx, d0, d1, d2); + auto lbClampMap = AffineMap::get(3, 0, {d0 - d1, d2}, ctx); + auto ubClampMap = rewriter.getMultiDimIdentityMap(2); + OpFoldResult lbClamped = affine::makeComposedFoldedAffineMax( + rewriter, loc, lbClampMap, {ub, offset, rewriter.getIndexAttr(0)}); + OpFoldResult ubClamped = affine::makeComposedFoldedAffineMin( + rewriter, loc, ubClampMap, {lbClamped, size}); + clampedSizes.push_back(ubClamped); + } + + // Compute the clamped type. This could be rank reduced, but rank reduced + // dimensions will never be potentially zero by construction. The earlier + // matchers ensure that all sliceable users are not rank reduced along a + // dimensions that is being sliced by the loop consumer. 
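+  // For example, a tensor.parallel_insert_slice with static sizes [1, 2] and
+  // a tensor<2xf32> source yields a rank reduction mask of {0}, so the
+  // clamped extract created below keeps only the rank-1 size.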
+ llvm::SmallDenseSet rankReductionMask = + computeRankReductionMask(parallelInsertOp.getStaticSizes(), + parallelInsertOp.getSourceType().getShape(), + /*matchDynamic=*/true) + .value(); + SmallVector clampedShape; + SmallVector rankReducedClampedSizes; + SmallVector d; + for (auto [idx, clampedSize] : llvm::enumerate(clampedSizes)) { + if (rankReductionMask.contains(idx)) { + continue; + } + dispatchIndexOpFoldResult(clampedSize, d, clampedShape); + rankReducedClampedSizes.push_back(clampedSize); + } + RankedTensorType clampedType = + parallelInsertOp.getSourceType().clone(clampedShape); + // Create an extract_slice to extract the correct size from the parallel + // insert source. + SmallVector zeros(clampedType.getRank(), + rewriter.getIndexAttr(0)); + SmallVector ones(clampedType.getRank(), + rewriter.getIndexAttr(1)); + Operation *combiningOp = + parallelInsertOp.getParallelCombiningParent().getOperation(); + rewriter.setInsertionPoint(combiningOp); + loc = combiningOp->getLoc(); + auto extractOp = rewriter.create( + loc, clampedType, parallelInsertOp.getSource(), zeros, + rankReducedClampedSizes, ones); + + // Replace the parallel insert op with the clamped version, and return the + // new parallel insert slice. + rewriter.setInsertionPoint(parallelInsertOp); + loc = parallelInsertOp->getLoc(); + return rewriter.replaceOpWithNewOp( + parallelInsertOp, extractOp.getResult(), parallelInsertOp.getDest(), + parallelInsertOp.getMixedOffsets(), clampedSizes, + parallelInsertOp.getMixedStrides()); +} + +FailureOr +fuseExtractSliceIntoProducerForall(RewriterBase &rewriter, + scf::ForallOp forallOp, + tensor::ExtractSliceOp extractSliceOp) { + auto forallResult = cast(extractSliceOp.getSource()); + if (!forallResult.hasOneUse()) { + return rewriter.notifyMatchFailure(forallOp, + "forall result has multiple uses"); + } + BlockArgument initBbarg = + forallOp.getRegionIterArgs()[forallResult.getResultNumber()]; + SmallVector parallelInsertOps = + forallOp.getCombiningOps(initBbarg); + if (parallelInsertOps.size() != 1) { + return rewriter.notifyMatchFailure( + forallOp, "Expected a single parallel_insert_slice"); + } + + auto parallelInsertOp = + dyn_cast(parallelInsertOps.front()); + if (!parallelInsertOp) { + return rewriter.notifyMatchFailure( + forallOp, "Expected parallel_insert_slice combining op"); + } + + // Only zero offset extract_slice ops are supported. + if (!areAllConstantIntValue(extractSliceOp.getMixedOffsets(), 0)) { + return rewriter.notifyMatchFailure(forallOp, + "extract_slice has non-zero offsets"); + } + + // The extract_slice index operands must dominate the forall loop in order + // to extract a slice of the init operand later. + DominanceInfo domInfo; + int64_t indexOperandStartIdx = + extractSliceOp.getOffsetSizeAndStrideStartOperandIndex(); + SmallVector indexOperands(extractSliceOp->getOperands().begin() + + indexOperandStartIdx, + extractSliceOp->getOperands().end()); + if (!llvm::all_of(indexOperands, + [&](Value v) { return domInfo.dominates(v, forallOp); })) { + return rewriter.notifyMatchFailure( + extractSliceOp, + "Extract slice index operands do not dominate the forall op"); + } + + // Compute the rank reduction mask of the extract_slice for resolving rank + // reduction at the end. For rank reducing slices, the extract_slice is + // fused into the loop as a non rank reducing slice, and then a collapse + // shape is added on the result of the loop. 
This simplifies the logic in + // this pattern, and other patterns for collapse shape fusion can then fuse + // this collapse shape into the loop if needed. + auto maybeRankReductionMask = computeRankReductionMask( + extractSliceOp.getStaticSizes(), extractSliceOp.getType().getShape(), + /*matchDynamic=*/true); + if (!maybeRankReductionMask) { + return rewriter.notifyMatchFailure(extractSliceOp, + "Could not compute rank reduction mask"); + } + + std::optional> + maybeInsertRankReductionMask = + computeRankReductionMask(parallelInsertOp.getStaticSizes(), + parallelInsertOp.getSourceType().getShape(), + /*matchDynamic=*/true); + if (!maybeInsertRankReductionMask) { + return rewriter.notifyMatchFailure(parallelInsertOp, + "Could not compute rank reduction mask"); + } + llvm::SmallDenseSet insertRankReductionMask = + maybeInsertRankReductionMask.value(); + + // Verify that the parallelInsertOp can be clamped to the sizes of the + // extractSliceOp. + if (failed(canClampParallelInsertSlice(rewriter, parallelInsertOp, + extractSliceOp, + insertRankReductionMask))) { + return failure(); + } + int64_t resultIdx = forallResult.getResultNumber(); + + // Clamp the parallel insert slice source to fit within the extracted slice. + SmallVector newInitSizes = extractSliceOp.getMixedSizes(); + FailureOr maybeClampedParallelInsertSliceOp = + clampParallelInsertSliceOp(rewriter, parallelInsertOp, newInitSizes); + if (failed(maybeClampedParallelInsertSliceOp)) { + return failure(); + } + tensor::ParallelInsertSliceOp clampedParallelInsertSliceOp = + maybeClampedParallelInsertSliceOp.value(); + + // Now replace users of the forall loop init argument with the output operand + // from outside the loop. Do not replace the clamped parallel insert dest. + Value forallOutput = forallOp.getOutputs()[forallResult.getResultNumber()]; + rewriter.replaceUsesWithIf(initBbarg, forallOutput, [&](OpOperand &operand) { + return operand != clampedParallelInsertSliceOp.getDestMutable(); + }); + + // Clone the extract_slice, and replace the source with the forall init + // operand. + Value forallInit = forallOp.getOutputs()[resultIdx]; + rewriter.setInsertionPoint(forallOp); + auto extractedInit = rewriter.create( + forallOp->getLoc(), forallInit, extractSliceOp.getMixedOffsets(), + extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides()); + + // Clone the forall op with the extracted init operand to replace the + // original forall op. + Location loc = forallOp->getLoc(); + rewriter.setInsertionPoint(forallOp); + SmallVector newForallOutputs(forallOp.getOutputs()); + newForallOutputs[resultIdx] = extractedInit.getResult(); + + scf::ForallOp newForallOp = rewriter.create( + loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(), + forallOp.getMixedStep(), newForallOutputs, forallOp.getMappingAttr()); + + SmallVector argReplacements(newForallOp.getInductionVars()); + argReplacements.append(newForallOp.getRegionIterArgs().begin(), + newForallOp.getRegionIterArgs().end()); + newForallOp.getTerminator()->erase(); + rewriter.mergeBlocks(forallOp.getBody(), newForallOp.getBody(), + argReplacements); + + // Create a collapse_shape to handle rank reduction. 
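+  // When the extract_slice is not rank reducing, every reassociation group is
+  // a singleton, so the collapse_shape is a no-op and folds away during
+  // canonicalization.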
+ Value extractedResult = newForallOp->getResult(resultIdx); + auto forallResultType = cast(extractedResult.getType()); + SmallVector reassociations; + ReassociationIndices reassociation; + for (int i = 0; i < forallResultType.getRank(); ++i) { + if (maybeRankReductionMask->contains(i)) { + reassociation.push_back(i); + continue; + } + reassociation.push_back(i); + reassociations.push_back(reassociation); + reassociation = {}; + } + auto collapseShape = rewriter.create( + extractSliceOp->getLoc(), extractedResult, reassociations); + + // Replace forall and extract_slice ops with the new operations. + rewriter.replaceAllOpUsesWith(extractSliceOp, collapseShape); + rewriter.replaceOp(forallOp, newForallOp); + return newForallOp; +} + //===----------------------------------------------------------------------===// // MultiMmaOp Lowering //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h index 444762b09ed75..5de98aae0c79c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h @@ -99,6 +99,51 @@ fuseCollapseShapeIntoProducerForall(RewriterBase &rewriter, scf::ForallOp forallOp, tensor::CollapseShapeOp collapseOp); +/// Function to fuse an extract slice op into a forall op producer. This rewrite +/// effectively bubbles the extract_slice op up through the forall output +/// operand, and the block argument inside the forall becomes the size of the +/// slice. The parallel_insert_slice user of the init block argument will have +/// its source clamped to fit into the sliced destination, and all other uses +/// of the block argument will be replaced with the value of the output operand +/// for the forall outside of the loop body. +/// *NOTE: This can create dynamically zero sized tensors inside the forall +/// body when the source of the parallel_insert_slice is clamped. +/// +/// The following example illustrates a simple case of this transformation: +/// ``` +/// %forall = scf.forall ... shared_outs(%arg = %dest) -> tensor<16xf32> { +/// %user = "some_user" %arg +/// ... +/// scf.in_parallel { +/// tensor.parallel_insert_slice %val into %arg ... +/// tensor<4xf32> into tensor<16xf32> +/// } +/// } +/// %extract = tensor.extract_slice %forall ... +/// tensor<16xf32> into tensor +/// ``` +/// After the transformation this would become: +/// ``` +/// %extract = tensor.extract_slice %dest ... +/// tensor<16xf32> into tensor +/// %forall = scf.forall ... shared_outs(%arg = %extract) -> tensor { +/// // The user now has the dest from outside the loop as its operand. +/// %user = "some_user" %dest +/// ... +/// // `%clamped_val` can be dynamically zero sized. +/// %clamped_val = tensor.extract_slice %val ... +/// tensor<4xf32> to tensor +/// scf.in_parallel { +/// tensor.parallel_insert_slice %clamped_val into %arg ... +/// tensor into tensor +/// } +/// } +/// ``` +FailureOr +fuseExtractSliceIntoProducerForall(RewriterBase &rewriter, + scf::ForallOp forallOp, + tensor::ExtractSliceOp extractSliceOp); + // Helper to convert a contraction-like linalg op to an iree_gpu.multi_mma. FailureOr convertContractionToMultiMma(RewriterBase &rewriter, linalg::LinalgOp linalgOp,