[MLIR] Add pattern to bubble up tensor.extract_slice #126898
Conversation
Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape. When added as a cleanup pattern of the tile-and-fuse utility, this pattern enables tiling and fusing op chains that contain tensor.expand_shape; without it, that would not be possible, since tensor.expand_shape does not implement the tiling interface. In addition, this pattern is registered as a cleanup pattern for transform.structured.fuse. The pattern was first implemented in the IREE project by Quinn Dawkins and is being upstreamed.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: ofri frishman (ofri-frishman)

Changes: Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape.

Full diff: https://github.com/llvm/llvm-project/pull/126898.diff

5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1..dc4558a605a59 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
void populateMergeConsecutiveInsertExtractSlicePatterns(
RewritePatternSet &patterns);
+/// Appends patterns that are used to bubble up tensor.extract_slice op above
+/// its producer. When used as cleanup patterns of tile and fuse, enables fusing
+/// the producer with the consumer even if the producer does not implement the
+/// tiling interface.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
/// Populates `patterns` with patterns that drop redundant tensor.insert_slice
/// rank expansions.
void populateDropRedundantInsertSliceRankExpansionPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 51d1df52598c7..5146bebe0108e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
RewritePatternSet patterns(context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+ tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp
new file mode 100644
index 0000000000000..a0d3c6d25bbe8
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp
@@ -0,0 +1,207 @@
+//===- BubbleUpExtractSlice.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Swap a `tensor.extract_slice` with the producer of the source in some cases
+// where that is valid. When used as cleanup patterns of tile and fuse, enables
+// fusing the producer with the consumer even if the producer does not implement
+// the tiling interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducing, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+/// tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+/// tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+/// tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+static LogicalResult
+swapExpandShapeWithSlice(RewriterBase &rewriter,
+ tensor::ExpandShapeOp expandShapeOp,
+ tensor::ExtractSliceOp sliceOp) {
+ SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+ if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unimplemented: rank reducing slice");
+ }
+
+ // Helper variables and function for accumulating the new offset and length
+ // values.
+ Location loc = expandShapeOp->getLoc();
+ AffineExpr d0, d1, d2;
+ bindDims(rewriter.getContext(), d0, d1, d2);
+ // Multiply two integers.
+ auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+ auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+ return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+ {v1, v2});
+ };
+
+ SmallVector<OpFoldResult> outputShape =
+ getMixedValues(expandShapeOp.getStaticOutputShape(),
+ expandShapeOp.getOutputShape(), rewriter);
+
+ auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize,
+ OpFoldResult size) {
+ if (!isConstantIntValue(offset, 0))
+ return false;
+ FailureOr<bool> maybeEqual =
+ ValueBoundsConstraintSet::areEqual(sliceSize, size);
+ return llvm::succeeded(maybeEqual) && maybeEqual.value();
+ };
+
+ // First verify that this is a full slice of the expanded tensor.
+ for (const ReassociationIndices &indices :
+ expandShapeOp.getReassociationIndices()) {
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Find the first expanded dim after the first dim with non-unit extracted
+ // size.
+ for (; i < e; ++i) {
+ if (!isConstantIntValue(sizes[indices[i]], 1)) {
+ // +1 to skip the first non-unit size dim.
+ i++;
+ break;
+ }
+ }
+
+ // Verify that all subsequent dimensions extract the full size of the
+ // source tensor.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+ outputShape[expandedDim])) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "Not a contiguous slice of the expanded tensor.");
+ }
+ }
+ }
+
+ // Compute new offsets, lengths, and strides.
+ SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+ for (const ReassociationIndices &indices :
+ expandShapeOp.getReassociationIndices()) {
+ OpFoldResult newSize = rewriter.getIndexAttr(1);
+ SmallVector<OpFoldResult> basis, delinOffsets;
+
+ int64_t i = 0;
+ int64_t e = indices.size();
+ // Offset = cumulative product of leading unit extracted dims.
+ for (; i < e; ++i) {
+ int64_t expandedDim = indices[i];
+ if (!isConstantIntValue(sizes[expandedDim], 1))
+ break;
+
+ basis.push_back(outputShape[expandedDim]);
+ delinOffsets.push_back(offsets[expandedDim]);
+ }
+
+ if (i != e) {
+ int64_t expandedDim = indices[i];
+ basis.push_back(outputShape[expandedDim]);
+ delinOffsets.push_back(offsets[expandedDim]);
+ newSize = sizes[expandedDim];
+ i++;
+ }
+
+ for (; i < e; ++i) {
+ OpFoldResult fullSize = outputShape[indices[i]];
+ basis.push_back(fullSize);
+ delinOffsets.push_back(rewriter.getIndexAttr(0));
+ newSize = mul(newSize, fullSize);
+ }
+ SmallVector<Value> offsetVals =
+ llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+ return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+ });
+ OpFoldResult newOffset = rewriter
+ .create<affine::AffineLinearizeIndexOp>(
+ loc, offsetVals, basis, /*disjoint=*/true)
+ .getResult();
+ newOffsets.push_back(newOffset);
+ newLengths.push_back(newSize);
+
+ // Only unit stride supported.
+ newStrides.push_back(rewriter.getIndexAttr(1));
+ }
+
+ // The shape of the result can be obtained from the sizes passed in.
+ SmallVector<Value> dynDims;
+ SmallVector<int64_t> shape;
+ dispatchIndexOpFoldResults(sizes, dynDims, shape);
+ RankedTensorType resultType = RankedTensorType::get(
+ shape, expandShapeOp.getResultType().getElementType());
+
+ // Create a new ExtractSliceOp and ExpandShapeOp.
+ Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+ loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
+ auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+ loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
+ sizes);
+ rewriter.replaceOp(sliceOp, newExpandShapeOp);
+ return success();
+}
+
+namespace {
+
+struct SwapExpandShapeWithSlicePattern
+ : public OpRewritePattern<tensor::ExtractSliceOp> {
+ using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+ if (!expandOp) {
+ return failure();
+ }
+
+ if (!sliceOp.hasUnitStride()) {
+ return rewriter.notifyMatchFailure(sliceOp,
+ "unsupported: non-unit stride");
+ }
+
+ return swapExpandShapeWithSlice(rewriter, expandOp, sliceOp);
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+ RewritePatternSet &patterns) {
+ patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cc6275fee671a..634cc93a08352 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
RewriteAsConstant.cpp
SwapExtractSliceWithProducerPatterns.cpp
SubsetInsertionOpInterfaceImpl.cpp
+ BubbleUpExtractSlice.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index ac1ca9319d335..22796611c5934 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -278,3 +278,141 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+ %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+ %empty = tensor.empty() : tensor<2x3x10xf32>
+ %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+ return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+ %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+ %empty = tensor.empty() : tensor<3x4x10xf32>
+ %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+ return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+// CHECK: tensor.expand_shape
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: scf.for
+// CHECK: linalg.exp
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+ %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+ %empty = tensor.empty() : tensor<3x4x10xf32>
+ %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+ return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: scf.for %[[W:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+// CHECK: %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+module {
+ func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> {
+ %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+ %empty = tensor.empty() : tensor<3x4x10x7x8xf32>
+ %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32>
+ return %exp : tensor<3x4x10x7x8xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+// CHECK: %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32)
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32>
+// CHECK: %[[ABS:.+]] = linalg.abs ins(%[[SLICE]]
+// CHECK: %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32]
+// CHECK: linalg.exp ins(%[[EXPAND]]
+module {
+ func.func @swap_expand_shape_with_extract_slice_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> {
+ %empty1 = tensor.empty() : tensor<1x1800x256xf32>
+ %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32>
+ %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32>
+ %empty2 = tensor.empty() : tensor<1x1800x8x32xf32>
+ %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32>
+ return %exp2 : tensor<1x1800x8x32xf32>
+ }
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+ %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true :
+ (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+ transform.yield
+ }
+}
There are very similar patterns in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp. How are they related? Should the new pattern also live in that file?
I agree that the patterns in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp are similar. The pattern I added is meant to swap an expand_shape with an extract_slice; when used as a cleanup pattern for the tile-and-fuse utility, it enables pulling the expand_shape into a loop nest even though it does not implement the tiling interface.
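For context, a minimal sketch of that wiring when driving the SCF tile-and-fuse utility directly (the cleanupPatterns field is the same one set for transform.structured.fuse in this diff; the surrounding driver code and variable names are illustrative, not part of this PR):

```cpp
// Sketch: register the bubble-up pattern as a cleanup pattern of the SCF
// tile-and-fuse utility. Tiling options and the consumer op are assumed to
// be set up elsewhere.
scf::SCFTileAndFuseOptions tileAndFuseOptions;
RewritePatternSet cleanupPatterns(ctx);
tensor::populateBubbleUpExtractSliceOpPatterns(cleanupPatterns);
tileAndFuseOptions.cleanupPatterns = std::move(cleanupPatterns);
FailureOr<scf::SCFTileAndFuseResult> result =
    scf::tileConsumerAndFuseProducersUsingSCF(rewriter, consumerOp,
                                              tileAndFuseOptions);
```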
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     RewritePatternSet patterns(context);
     tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
I don't have a strong opinion on this transform op, but I don't know if I would bucket this pattern as "cleanup." If there are no objections this is fine with me.
IIUC, this transformation is "unblocking" tiling. As such, it feels much more impactful than merely "cleaning up".
I agree that the pattern "unblocks" fusion of operations into a single tiling loop nest. I think other patterns used for cleanup assist with that as well. From what I understand, the possibility to clean up between fusion steps was added precisely for this reason: the idea was proposed in the thread https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155 and first implemented in #109554. The lit tests added in that PR show cases where fusion would previously have got stuck but no longer does, so I thought it would make sense to add this pattern to the cleanup patterns as well. But perhaps I misunderstood something about the nature and use of the cleanup patterns. How else would you suggest adding the ability to apply this pattern between fusion steps of FuseOp?
Ok, makes sense to me. Thanks for the analysis! Discoverability has been a big headache for me: there are many useful patterns, but it's really hard to find whether a certain pattern already exists.
Let me suggest a partial solution. We should be able to identify relevant transformations by simply "grep"-ing for "interesting" ops in the test directory. For this to work, we need to make sure that every pattern is tested in isolation.
As a stepping stone, @ofri-frishman, please make sure that the new pattern is tested in isolation. As for the right location, there seem to be two good candidates: ReshapePatterns.cpp or a new dedicated file.
This pattern feels like "bubbling up" to me, though "swap" might actually be more fitting. Just to avoid bike-shedding, I suggest "BubbleUpExtractSlice.cpp", but let's note (in comments) that this is basically a "swap". We can revisit later.
Just a general question: currently on GitHub I am defined as a contributor to the llvm project, so I cannot add reviewers to the PR and will not be able to land it once approved. Do you know who I can ask for the relevant permissions?
Hey, https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access should have the relevant information.
Re: landing the PR. I don't know if the policy has changed, but typically for first time contributors a reviewer can land the PR once approved/ready. |
Thanks for contributing this! I wrote most of the code here, so it might be worth getting final approval from someone else. +1 to adding a separate testing op from transform.structured.fuse.
                  PatternRewriter &rewriter) const override {
    auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
    if (!expandOp) {
      return failure();
Suggested change:
```cpp
-      return failure();
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice source not produced by expand_shape");
```
    auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
        loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
        sizes);
    rewriter.replaceOp(sliceOp, newExpandShapeOp);
nit: Can use replaceOpWithNewOp
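A sketch of the suggested simplification, assuming the standard PatternRewriter API (behavior should be unchanged):

```cpp
// Fold the create + replaceOp pair into a single replaceOpWithNewOp call:
// builds the new tensor.expand_shape and replaces sliceOp with its result.
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
    sliceOp, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
    sizes);
```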
I was actually wondering how best to go about that. I wasn't aware of the option to add test ops; I was thinking about adding a test pass. Do you think the best way to have a test that isolates the pattern is via a test op (I guess you mean a transform op)? If so, any chance you could point me to an example of such an op?
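For what it's worth, a pattern-isolating test pass along those lines might look roughly like this (a hedged sketch only: the pass name and structure are hypothetical and not part of this PR; includes and registration boilerplate omitted):

```cpp
// Hypothetical test pass that exercises only the new bubble-up pattern by
// applying it greedily to each function, in the spirit of existing MLIR
// test passes.
namespace {
struct TestBubbleUpExtractSlicePass
    : public PassWrapper<TestBubbleUpExtractSlicePass,
                         OperationPass<func::FuncOp>> {
  StringRef getArgument() const final {
    return "test-bubble-up-extract-slice";
  }
  StringRef getDescription() const final {
    return "Test bubbling up tensor.extract_slice through "
           "tensor.expand_shape";
  }
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace
```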