Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add pattern to bubble up tensor.extract_slice #126898

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ofri-frishman
Copy link
Contributor

Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape.
This pattern enables tiling and fusing op chains which contain tensor.expand_shape if added as a cleanup pattern of tile and fuse utility.
Without this pattern that would not be possible, as tensor.expand_shape does not implement the tiling interface. In addition, registering this pattern as a cleanup pattern for transform.structured.fuse.
The pattren was first implement in IREE project by Quinn Dawkins and is being upstreamed.

Add a pattern that bubbles up tensor.extract_slice through
tensor.expand_shape.
This pattern enables tiling and fusing op chains which contain
tensor.expand_shape if added as a cleanup pattern of tile and fuse
utility.
Without this pattern that would not be possible, as
tensor.expand_shape does not implement the tiling interface.
In addition, registering this pattern as a cleanup pattern for
transform.structured.fuse.
The pattren was first implement in IREE project by
Quinn Dawkins and is being upstreamed.
@llvmbot
Copy link
Member

llvmbot commented Feb 12, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-tensor

@llvm/pr-subscribers-mlir

Author: ofri frishman (ofri-frishman)

Changes

Add a pattern that bubbles up tensor.extract_slice through tensor.expand_shape.
This pattern enables tiling and fusing op chains which contain tensor.expand_shape if added as a cleanup pattern of tile and fuse utility.
Without this pattern that would not be possible, as tensor.expand_shape does not implement the tiling interface. In addition, registering this pattern as a cleanup pattern for transform.structured.fuse.
The pattren was first implement in IREE project by Quinn Dawkins and is being upstreamed.


Full diff: https://github.com/llvm/llvm-project/pull/126898.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+1)
  • (added) mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp (+207)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-fuse.mlir (+138)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1..dc4558a605a59 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -58,6 +58,12 @@ void populateFoldTensorSubsetIntoVectorTransferPatterns(
 void populateMergeConsecutiveInsertExtractSlicePatterns(
     RewritePatternSet &patterns);
 
+/// Appends patterns that are used to bubble up tensor.extract slice op above
+/// its producer. When used as cleanup patterns of tile and fuse, enables fusing
+/// the producer with the consumer even if the producer does not implement the
+/// tiling interface.
+void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
+
 /// Populates `patterns` with patterns that drop redundant tensor.insert_slice
 /// rank expansions.
 void populateDropRedundantInsertSliceRankExpansionPatterns(
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 51d1df52598c7..5146bebe0108e 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
     RewritePatternSet patterns(context);
     tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
     tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
+    tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
     tileAndFuseOptions.cleanupPatterns = std::move(patterns);
   }
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp
new file mode 100644
index 0000000000000..a0d3c6d25bbe8
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/BubbleUpExtractSlice.cpp
@@ -0,0 +1,207 @@
+//===- BubbleUpExtractSlice.cpp ---------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Swap a `tensor.extract_slice` with the producer of the source in some cases
+// where that is valid. When used as cleanup patterns of tile and fuse, enables
+// fusing the producer with the consumer even if the producer does not implement
+// the tiling interface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ValueBoundsOpInterface.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Converts `tensor.extract_slice(tensor.expand_shape)` to
+/// `tensor.expand_shape(tensor.extract_slice)`.
+/// For this transformation to be possible, the slice must be fully contiguous
+/// within each reassociation group of the expand_shape. If the transformation
+/// is not possible, or if the slice is rank reducting, the function returns
+/// failure.
+///
+/// Example:
+/// ```
+/// %reshape = tensor.expand_shape %in [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x16x32xf32> to tensor<2x4x2x8x4x2x4xf32>
+/// %slice = tensor.extract_slice %reshape ...
+///     tensor<2x4x2x8x4x2x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+///
+/// // The transformation is possible because each reassociation group has a
+/// // contiguous slice. (i.e., [2x4->2x4], [2x8->1x5], [4x2x4->1x1x4])
+/// // After the transformation:
+///
+/// %slice = tensor.extract_slice %in ...
+///     tensor<8x16x32xf32> to tensor<8x5x4xf32>
+/// %reshape = tensor.expand_shape %slice [[0, 1], [2, 3], [4, 5, 6]]
+///     tensor<8x5x4xf32> to tensor<2x4x1x5x1x1x4xf32>
+/// ```
+static LogicalResult
+swapExpandShapeWithSlice(RewriterBase &rewriter,
+                         tensor::ExpandShapeOp expandShapeOp,
+                         tensor::ExtractSliceOp sliceOp) {
+  SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
+
+  if (static_cast<size_t>(sliceOp.getResultType().getRank()) != sizes.size()) {
+    return rewriter.notifyMatchFailure(sliceOp,
+                                       "unimplemented: rank reducing slice");
+  }
+
+  // Helper variables and function for accumulating the new offset and length
+  // values.
+  Location loc = expandShapeOp->getLoc();
+  AffineExpr d0, d1, d2;
+  bindDims(rewriter.getContext(), d0, d1, d2);
+  // Multiply two integers.
+  auto mul = [&](OpFoldResult v1, OpFoldResult v2) {
+    auto mulMap = AffineMap::get(2, 0, {d0 * d1});
+    return affine::makeComposedFoldedAffineApply(rewriter, loc, mulMap,
+                                                 {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> outputShape =
+      getMixedValues(expandShapeOp.getStaticOutputShape(),
+                     expandShapeOp.getOutputShape(), rewriter);
+
+  auto isZeroOffsetAndFullSize = [](OpFoldResult offset, OpFoldResult sliceSize,
+                                    OpFoldResult size) {
+    if (!isConstantIntValue(offset, 0))
+      return false;
+    FailureOr<bool> maybeEqual =
+        ValueBoundsConstraintSet::areEqual(sliceSize, size);
+    return llvm::succeeded(maybeEqual) && maybeEqual.value();
+  };
+
+  // First verify that this is a full slice of the expanded tensor.
+  for (const ReassociationIndices &indices :
+       expandShapeOp.getReassociationIndices()) {
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Find the first expanded dim after the first dim with non-unit extracted
+    // size.
+    for (; i < e; ++i) {
+      if (!isConstantIntValue(sizes[indices[i]], 1)) {
+        // +1 to skip the first non-unit size dim.
+        i++;
+        break;
+      }
+    }
+
+    // Verify that all subsequent dimensions extract the full size of the
+    // source tensor.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isZeroOffsetAndFullSize(offsets[expandedDim], sizes[expandedDim],
+                                   outputShape[expandedDim])) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "Not a contiguous slice of the expanded tensor.");
+      }
+    }
+  }
+
+  // Compute new offsets, lengths, and strides.
+  SmallVector<OpFoldResult> newOffsets, newLengths, newStrides;
+  for (const ReassociationIndices &indices :
+       expandShapeOp.getReassociationIndices()) {
+    OpFoldResult newSize = rewriter.getIndexAttr(1);
+    SmallVector<OpFoldResult> basis, delinOffsets;
+
+    int64_t i = 0;
+    int64_t e = indices.size();
+    // Offset = cumulative product of leading unit extracted dims.
+    for (; i < e; ++i) {
+      int64_t expandedDim = indices[i];
+      if (!isConstantIntValue(sizes[expandedDim], 1))
+        break;
+
+      basis.push_back(outputShape[expandedDim]);
+      delinOffsets.push_back(offsets[expandedDim]);
+    }
+
+    if (i != e) {
+      int64_t expandedDim = indices[i];
+      basis.push_back(outputShape[expandedDim]);
+      delinOffsets.push_back(offsets[expandedDim]);
+      newSize = sizes[expandedDim];
+      i++;
+    }
+
+    for (; i < e; ++i) {
+      OpFoldResult fullSize = outputShape[indices[i]];
+      basis.push_back(fullSize);
+      delinOffsets.push_back(rewriter.getIndexAttr(0));
+      newSize = mul(newSize, fullSize);
+    }
+    SmallVector<Value> offsetVals =
+        llvm::map_to_vector(delinOffsets, [&](OpFoldResult ofr) {
+          return getValueOrCreateConstantIndexOp(rewriter, loc, ofr);
+        });
+    OpFoldResult newOffset = rewriter
+                                 .create<affine::AffineLinearizeIndexOp>(
+                                     loc, offsetVals, basis, /*disjoint=*/true)
+                                 .getResult();
+    newOffsets.push_back(newOffset);
+    newLengths.push_back(newSize);
+
+    // Only unit stride supported.
+    newStrides.push_back(rewriter.getIndexAttr(1));
+  }
+
+  // The shape of the result can be obtained from the sizes passed in.
+  SmallVector<Value> dynDims;
+  SmallVector<int64_t> shape;
+  dispatchIndexOpFoldResults(sizes, dynDims, shape);
+  RankedTensorType resultType = RankedTensorType::get(
+      shape, expandShapeOp.getResultType().getElementType());
+
+  // Create a new ExtractSliceOp and ExpandShapeOp.
+  Value newSliceOp = rewriter.create<tensor::ExtractSliceOp>(
+      loc, expandShapeOp.getSrc(), newOffsets, newLengths, newStrides);
+  auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+      loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
+      sizes);
+  rewriter.replaceOp(sliceOp, newExpandShapeOp);
+  return success();
+}
+
+namespace {
+
+struct SwapExpandShapeWithSlicePattern
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
+    if (!expandOp) {
+      return failure();
+    }
+
+    if (!sliceOp.hasUnitStride()) {
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "unsupported: non-unit stride");
+    }
+
+    return swapExpandShapeWithSlice(rewriter, expandOp, sliceOp);
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateBubbleUpExtractSliceOpPatterns(
+    RewritePatternSet &patterns) {
+  patterns.add<SwapExpandShapeWithSlicePattern>(patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index cc6275fee671a..634cc93a08352 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
   SubsetInsertionOpInterfaceImpl.cpp
+  BubbleUpExtractSlice.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
index ac1ca9319d335..22796611c5934 100644
--- a/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-fuse.mlir
@@ -278,3 +278,141 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[Z]]] by (2, 3, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [5] [1] : tensor<60xf32> to tensor<5xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 1, 5]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice(%0: tensor<60xf32>) -> tensor<2x3x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [2, 3, 10] : tensor<60xf32> into tensor<2x3x10xf32>
+  %empty = tensor.empty() : tensor<2x3x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<2x3x10xf32>) outs(%empty : tensor<2x3x10xf32>) -> tensor<2x3x10xf32>
+  return %exp : tensor<2x3x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 1, 5] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]]{{.*}} by (3, 4, 10)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX]]] [20] [1] : tensor<120xf32> to tensor<20xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2]] output_shape [1, 2, 10]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:2 = transform.structured.fuse %0 [1, 2, 0] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_full_inner_dim
+//     CHECK: tensor.expand_shape
+//     CHECK: scf.for
+//     CHECK:   scf.for
+//     CHECK:     scf.for
+//     CHECK:       linalg.exp
+func.func @swap_expand_shape_with_extract_slice_full_inner_dim(%0: tensor<120xf32>) -> tensor<3x4x10xf32> {
+  %expand = tensor.expand_shape %0 [[0, 1, 2]] output_shape [3, 4, 10] : tensor<120xf32> into tensor<3x4x10xf32>
+  %empty = tensor.empty() : tensor<3x4x10xf32>
+  %exp = linalg.exp ins(%expand : tensor<3x4x10xf32>) outs(%empty : tensor<3x4x10xf32>) -> tensor<3x4x10xf32>
+  return %exp : tensor<3x4x10xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:3 = transform.structured.fuse %0 [1, 2, 5] interchange [0, 1, 2] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims
+//     CHECK: %[[C0:.+]] = arith.constant 0 : index
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:   scf.for %[[Y:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:     scf.for %[[Z:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       scf.for %[[W:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:       %[[LINEAR_IDX0:.+]] = affine.linearize_index disjoint [%[[X]], %[[Y]], %[[C0]]] by (3, 4, 10)
+//     CHECK:       %[[LINEAR_IDX1:.+]] = affine.linearize_index disjoint [%[[Z]], %[[W]]] by (7, 8)
+//     CHECK:       %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[%[[LINEAR_IDX0]], %[[LINEAR_IDX1]]] [20, 4] [1, 1] : tensor<120x56xf32> to tensor<20x4xf32>
+//     CHECK:       %[[EXPAND:.+]] = tensor.expand_shape %[[SLICE]] {{\[\[}}0, 1, 2], [3, 4]] output_shape [1, 2, 10, 1, 4]
+//     CHECK:       linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_multiple_expanded_dims(%0: tensor<120x56xf32>) -> tensor<3x4x10x7x8xf32> {
+    %expand = tensor.expand_shape %0 [[0, 1, 2], [3, 4]] output_shape [3, 4, 10, 7, 8] : tensor<120x56xf32> into tensor<3x4x10x7x8xf32>
+    %empty = tensor.empty() : tensor<3x4x10x7x8xf32>
+    %exp = linalg.exp ins(%expand : tensor<3x4x10x7x8xf32>) outs(%empty : tensor<3x4x10x7x8xf32>) -> tensor<3x4x10x7x8xf32>
+    return %exp : tensor<3x4x10x7x8xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:4 = transform.structured.fuse %0 [1, 2, 0, 1, 4] interchange [0, 1, 2, 3, 4] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">, !transform.any_op, !transform.any_op, !transform.any_op)
+    transform.yield 
+  }
+}
+
+// -----
+
+//     CHECK: scf.for %[[X:[A-Za-z0-9]+]] = {{.*}}
+//     CHECK:    %[[LINEAR_IDX:.+]] = affine.linearize_index disjoint [%[[X]], {{.*}} by (8, 32)
+//     CHECK:    %[[SLICE:.+]] = tensor.extract_slice %{{.*}}[0, 0, %[[LINEAR_IDX]]] [1, 1800, 32] [1, 1, 1] : tensor<1x1800x256xf32> to tensor<1x1800x32xf32>
+//     CHECK:    %[[ABS:.+]] = linalg.abs ins(%[[SLICE]]
+//     CHECK:    %[[EXPAND:.+]] = tensor.expand_shape %[[ABS]] {{\[\[}}0], [1], [2, 3]] output_shape [1, 1800, 1, 32]
+//     CHECK:    linalg.exp ins(%[[EXPAND]]
+module {
+  func.func @swap_expand_shape_with_extract_slice_and_fuse_with_expand_producer(%0: tensor<1x1800x256xf32>) -> tensor<1x1800x8x32xf32> {
+    %empty1 = tensor.empty() : tensor<1x1800x256xf32>
+    %exp1 = linalg.abs ins(%0 : tensor<1x1800x256xf32>) outs(%empty1 : tensor<1x1800x256xf32>) -> tensor<1x1800x256xf32>
+    %expand = tensor.expand_shape %exp1 [[0], [1], [2, 3]] output_shape [1, 1800, 8, 32] : tensor<1x1800x256xf32> into tensor<1x1800x8x32xf32>
+    %empty2 = tensor.empty() : tensor<1x1800x8x32xf32>
+    %exp2 = linalg.exp ins(%expand : tensor<1x1800x8x32xf32>) outs(%empty2 : tensor<1x1800x8x32xf32>) -> tensor<1x1800x8x32xf32>
+    return %exp2 : tensor<1x1800x8x32xf32>
+  }
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.exp"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %transformed, %loops:1 = transform.structured.fuse %0 [0, 0, 1, 0] interchange [0, 1, 2, 3] apply_cleanup = true : 
+      (!transform.any_op) -> (!transform.any_op, !transform.op<"scf.for">)
+    transform.yield 
+  }
+}
+
+
+
+

@mgehre-amd
Copy link
Contributor

There are very similar patterns in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp. How are they related? Should the new pattern also live in that file?

@ofri-frishman
Copy link
Contributor Author

I agree that the patterns in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp are similar. The pattern I added is meant to swap between expand_shape and extract_slice which when used as a cleanup pattern for tile and fuse utility enables adding the expand_shape into a loop nest even though it does not implement the tiling interface.
I'm not familiar the the use of the other patterns, but don't think it is related to tiling.
I placed it in a separate file since I thought that other patterns that bubble up extract_slice through tensor ops could be placed there as well, so the emphasis is more on the extract_slice than on the expand_shape. But currently there is only one such pattern, so it could be placed together with the other reshape patterns.
But it would still require a separate populate function, since only this pattern should be added to the FuseOp cleanup patterns.
BTW there is another place with a similar pattern - in mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp there is a pattern that bubbles up a extract_slice through a linalg op.
Given this, if you think that the best place for this pattern is in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp I can move it to there

@hanhanW hanhanW requested a review from qedawkins February 12, 2025 16:42
@@ -582,6 +582,7 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
RewritePatternSet patterns(context);
tensor::ExtractSliceOp::getCanonicalizationPatterns(patterns, context);
tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
tensor::populateBubbleUpExtractSliceOpPatterns(patterns);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion on this transform op, but I don't know if I would bucket this pattern as "cleanup." If there are no objections this is fine with me.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, this transformation is "unblocking" tiling. As such, it feels much more impactful than merely "cleaning up".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that the pattern "unblocks" fusion of operations into a single tiling loop nest.
I think that also other patterns used for cleanup assist with that.
From what I understand, the concept of adding a possibility to cleanup between fusion steps was added precisely for this reason. As part of the thread in https://discourse.llvm.org/t/rfc-split-fusion-portions-of-the-tilinginterface-into-a-new-interface/81155 the idea to add such a step was proposed, and it was first implement in #109554.
The lit tests added in that PR show cases where previously fusion would have got stuck, but after the change they do not. So, I thought it would make sense to add this pattern to the cleanup patterns as well.
But perhaps I misunderstood something about the nature and use of the cleanup patterns.
How otherwise would you suggest to add the ability to apply this pattern between fusion steps of FuseOp?

@mgehre-amd
Copy link
Contributor

I agree that the patterns in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp are similar. The pattern I added is meant to swap between expand_shape and extract_slice which when used as a cleanup pattern for tile and fuse utility enables adding the expand_shape into a loop nest even though it does not implement the tiling interface. I'm not familiar the the use of the other patterns, but don't think it is related to tiling. I placed it in a separate file since I thought that other patterns that bubble up extract_slice through tensor ops could be placed there as well, so the emphasis is more on the extract_slice than on the expand_shape. But currently there is only one such pattern, so it could be placed together with the other reshape patterns. But it would still require a separate populate function, since only this pattern should be added to the FuseOp cleanup patterns. BTW there is another place with a similar pattern - in mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp there is a pattern that bubbles up a extract_slice through a linalg op. Given this, if you think that the best place for this pattern is in mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp I can move it to there

Ok, make sense to me. Thanks for the analysis! Discoverability has been a big headache for me: There are many useful patterns, but its really hard to find whether a certain pattern already exists.
You can't probably do much about it in this PR.
Anyway, your patterns looks very useful, so thanks for the PR!

@banach-space
Copy link
Contributor

Discoverability has been a big headache for me: There are many useful patterns, but its really hard to find whether a certain pattern already exists.

Let me suggest a partial solution.

We should be able to identify relevant transformations by simply "grep"-ing for "interesting" Ops in the test directory. For this to work, we need to make sure that:

  • we have good (meaningful + descriptive) test file names,
  • we have good (meaningful + descriptive) function names,
  • we test patterns in isolation and document what patterns are being excercised (e.g. here).

As a stepping stone, @ofri-frishman, please make sure that the new pattern is tested in isolation. As for the right location, there seems to be 2 good candidates:

  • "mlir/lib/Dialect/Linalg/Transforms/BubbleUpExtractSlice.cpp",
  • "mlir/lib/Dialect/Tensor/Transforms/ReshapePatterns.cpp".

This pattern feels like "bubbling up" to me, though "swap" might be more fitting actually. Just to avoid bike-shedding, I suggest "BubbleUpExtractSlice.cpp", but lets note (in comments) that this is basically a "swap". We can re-visit later.

@ofri-frishman
Copy link
Contributor Author

Just a general question - currently in github I see that I am defined as a contributor to the llvm project and cannot add reviewers to review the PR and will not be able to land the PR once approved. Do you know who I can ask to give me the relevant permissions?

@mgehre-amd
Copy link
Contributor

Just a general question - currently in github I see that I am defined as a contributor to the llvm project and cannot add reviewers to review the PR and will not be able to land the PR once approved. Do you know who I can ask to give me the relevant permissions?

Hey, https://llvm.org/docs/DeveloperPolicy.html#obtaining-commit-access should have the relevant information.

@qedawkins
Copy link
Contributor

Re: landing the PR. I don't know if the policy has changed, but typically for first time contributors a reviewer can land the PR once approved/ready.

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for contributing this! I wrote most of the code here so it might be worth getting final approval from someone else. +1 to adding a separate testing op from transform.structured.fuse.

PatternRewriter &rewriter) const override {
auto expandOp = sliceOp.getSource().getDefiningOp<tensor::ExpandShapeOp>();
if (!expandOp) {
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return failure();
return rewriter.notifyMatchFailure(sliceOp,
"slice source not produced by expand_shape");

auto newExpandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, resultType, newSliceOp, expandShapeOp.getReassociationIndices(),
sizes);
rewriter.replaceOp(sliceOp, newExpandShapeOp);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Can use replaceOpWithNewOp

@ofri-frishman
Copy link
Contributor Author

Thanks for contributing this! I wrote most of the code here so it might be worth getting final approval from someone else. +1 to adding a separate testing op from transform.structured.fuse.

I was actually wondering how best to go about that and wasn't aware of the option to add test ops, I was thinking about adding a test pass. You think the best way to have a test that isolates the pattern is via a test op (I guess you mean a transform op)? If so, any chance you could point me to an example of such an op?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants