From 3a6805e3b16125460ea7da5b43d2cbca76abcdad Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Thu, 16 Jan 2025 17:47:43 +0000 Subject: [PATCH] [GlobalOptimization] Do not hoist fill-like operations --- .../Dialect/LinalgExt/Utils/Utils.cpp | 20 +++++++++++++++++++ .../compiler/Dialect/LinalgExt/Utils/Utils.h | 4 ++++ .../ExternalInterfaces/UtilExternalModels.cpp | 3 ++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp index 9d0a27da7ba3..cad0c852bd98 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Builders.h" +#include "mlir/Transforms/RegionUtils.h" namespace mlir::iree_compiler::IREE::LinalgExt { @@ -363,6 +364,25 @@ bool isGatherlikeOp(Operation *op) { return true; } +bool isFillLikeOp(linalg::GenericOp linalgOp) { + // Check if there are any non-scalar inputs or non-scalar captures in the + // region. + for (Value input : linalgOp.getDpsInputs()) { + if (isa(input.getType())) { + return false; + } + } + + bool foundNonScalar = false; + visitUsedValuesDefinedAbove(linalgOp.getRegion(), [&](OpOperand *operand) { + if (isa(operand->get().getType())) { + foundNonScalar = true; + } + }); + + return !foundNonScalar; +} + FailureOr> getIGEMMContractionIndexingMaps(linalg::LinalgOp linalgOp) { MLIRContext *ctx = linalgOp->getContext(); diff --git a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h index c0609eb971e2..8c7a54ff37e3 100644 --- a/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h +++ b/compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h @@ -195,5 +195,9 @@ bool isBroadcastingOp(linalg::LinalgOp op); /// 2. `linalg.yield` consumes the result of a `tensor.extract_slice` bool isGatherlikeOp(Operation *op); +/// Returns true if the operation is a GenericOp that has no tensor inputs, +/// either as inputs or as implicit captures. +bool isFillLikeOp(linalg::GenericOp op); + } // namespace mlir::iree_compiler::IREE::LinalgExt #endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_ diff --git a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp index db5fa4007269..17ff95536ba4 100644 --- a/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp +++ b/compiler/src/iree/compiler/ExternalInterfaces/UtilExternalModels.cpp @@ -12,6 +12,7 @@ #include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtDialect.h" #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h" +#include "iree/compiler/Dialect/LinalgExt/Utils/Utils.h" #include "iree/compiler/Dialect/Util/IR/UtilTypes.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -325,7 +326,7 @@ struct HoistableLinalgOpInterface } } } - return !linalg::isaFillOpInterface(genericOp).has_value(); + return !IREE::LinalgExt::isFillLikeOp(genericOp); } bool isAtomicallyHoistableOp(Operation *) const { return true; } bool isOperandHoistable(Operation *, OpOperand *) const { return true; }