diff --git a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
index b1c90d2f6cef..f8e13f2fa0e8 100644
--- a/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
+++ b/compiler/plugins/input/Torch/InputConversion/BindSymbolicShapes.cpp
@@ -30,6 +30,14 @@ namespace mlir::iree_compiler::TorchInput {
 namespace {
 
+// We arbitrarily say that unbounded dimensions in a torch program cannot
+// exceed 53 bits, making the maximum safe dimension 9007199254740991. The
+// astute reader will note that this is also the maximum safe integer in
+// JavaScript, which also "happens" to be the largest mantissa value in a
+// 64-bit double. We need a maximum and in the absence of a better choice,
+// with this one we are at least in good company.
+static constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;
+
 // Torch "binds" symbolic shape information to all tensors in the program
 // which are not static. It does this by emitting side-effecting
 // torch.bind_symbolic_shape ops which are backed by torch.symbolic_int ops
@@ -95,15 +103,9 @@ class BindSymbolicShapesPass final
       auto maxVal = symbolDefOp.getMaxValAttr();
       if (minVal && maxVal) {
         uint64_t minValInt = minVal.getValue().getZExtValue();
-        uint64_t maxValInt = maxVal.getValue().getZExtValue();
-        // Note that torch represents open ranges in strange ways with various
-        // magic numbers in the high range of the uint64_t type. We somewhat
-        // arbitrarily say that anything over a fourth of the uint64_t
-        // range (which is half of the positive int64_t range, should these have
-        // originated as signed quantities), is a ridiculously large number not
-        // suitable as a shape dimension, and we drop the hint.
-        if (maxValInt >= minValInt &&
-            maxValInt < std::numeric_limits<uint64_t>::max() / 4) {
+        uint64_t maxValInt =
+            std::min(maxVal.getValue().getZExtValue(), MAX_DIM_VALUE);
+        if (maxValInt >= minValInt) {
           // Note that in Torch, min values are "weird" because they encode
           // some special cases about broadcast behavior. Here we just discard
           // them, but in the future, there may be more to derive here.
@@ -220,8 +222,8 @@ class BindSymbolicShapesPass final
     for (auto [pos, symbolValue] : llvm::enumerate(symbols)) {
       const SymbolInfo &symbolInfo = symbolInfos.at(symbolValue);
       if (!symbolInfo.minMaxBounds) {
-        lowerBounds.push_back({});
-        upperBounds.push_back({});
+        lowerBounds.push_back(1);
+        upperBounds.push_back(MAX_DIM_VALUE);
       } else {
         lowerBounds.push_back(symbolInfo.minMaxBounds->first);
         upperBounds.push_back(symbolInfo.minMaxBounds->second);
diff --git a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
index 8f78bf7dcb0c..699b6dbf6d60 100644
--- a/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
+++ b/compiler/plugins/input/Torch/InputConversion/test/bind_symbolic_shapes.mlir
@@ -153,14 +153,16 @@ module @unsupported_non_symbolic {
 // -----
 // Torch uses high values to signal unbounded ranges. Ensure they are
-// suppressed.
+// clamped.
 // CHECK-LABEL: @torch_unbounded_max_range
 module @torch_unbounded_max_range {
   func.func @main(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) {
-    // CHECK-NOT: util.assume.int
+    // CHECK: util.assume.int {{.*}}
     torch.bind_symbolic_shape %arg0, [%0, %1], affine_map<()[s0, s1] -> (s0, s1)> : !torch.vtensor<[?,?],f32>
+    // CHECK: util.assume.int {{.*}}
     torch.bind_symbolic_shape %arg1, [%0, %1], affine_map<()[s0, s1] -> (s0, s1 * 10)> : !torch.vtensor<[?,?],f32>
     return
   }
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
index fd1df8e0f2ac..3de051be39dd 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.cpp
@@ -1101,6 +1101,40 @@ void printShapedFunctionSignature(OpAsmPrinter &p, Operation *op,
 
 namespace mlir::iree_compiler::IREE::Util {
 
+//===----------------------------------------------------------------------===//
+// util.align
+//===----------------------------------------------------------------------===//
+
+void AlignOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                SetIntRangeFn setResultRange) {
+  auto constantAlignment = argRanges[1].getConstantValue();
+  // Note that for a non-constant alignment, there may still be something we
+  // want to infer, but this is left for the future.
+  if (constantAlignment) {
+    // We can align the range directly.
+    // (value + (alignment - 1)) & ~(alignment - 1)
+    // https://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding
+    APInt umin = argRanges[0].umin();
+    APInt umax = argRanges[0].umax();
+    APInt one(constantAlignment->getBitWidth(), 1);
+    APInt alignmentM1 = *constantAlignment - one;
+    APInt alignmentM1Inv = ~alignmentM1;
+    auto align = [&](APInt value) -> APInt {
+      return (value + alignmentM1) & alignmentM1Inv;
+    };
+    setResultRange(getResult(),
+                   ConstantIntRanges::fromUnsigned(align(umin), align(umax)));
+  }
+}
+
+void AlignOp::inferResultDivisibility(ArrayRef<IntegerDivisibility> argDivs,
+                                      SetIntDivisibilityFn setResultDivs) {
+  auto alignmentDiv = argDivs[1];
+  if (alignmentDiv.isUninitialized())
+    return;
+  setResultDivs(getResult(), alignmentDiv.getValue());
+}
+
 //===----------------------------------------------------------------------===//
 // util.assume.int
 //===----------------------------------------------------------------------===//
@@ -1120,39 +1154,45 @@ AssumeIntOp::getOperandAssumptions(unsigned operandIndex) {
 std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
 AssumeIntOp::getUnionedUnsignedRange(unsigned operandIndex) {
   auto assumptions = getOperandAssumptions(operandIndex);
-  std::optional<uint64_t> uminUnion;
-  std::optional<uint64_t> umaxUnion;
+  uint64_t uminUnion = std::numeric_limits<uint64_t>::max();
+  int uminCount = 0;
+  uint64_t umaxUnion = std::numeric_limits<uint64_t>::min();
+  int umaxCount = 0;
   for (auto assumption : assumptions) {
     auto umin = assumption.getUmin();
    auto umax = assumption.getUmax();
     if (umin) {
-      uminUnion = std::min(
-          *umin, uminUnion ? *uminUnion : std::numeric_limits<uint64_t>::max());
+      uminUnion = std::min(*umin, uminUnion);
+      uminCount += 1;
     }
     if (umax) {
-      umaxUnion = std::max(
-          *umax, umaxUnion ? *umaxUnion : std::numeric_limits<uint64_t>::min());
+      umaxUnion = std::max(*umax, umaxUnion);
+      umaxCount += 1;
     }
   }
-  return std::make_pair(uminUnion, umaxUnion);
+  return std::make_pair(uminCount > 0 && uminCount == assumptions.size()
+                            ? std::optional<uint64_t>(uminUnion)
+                            : std::nullopt,
+                        umaxCount > 0 && umaxCount == assumptions.size()
+                            ? std::optional<uint64_t>(umaxUnion)
+                            : std::nullopt);
 }
 
-// Gets the unioned divisor for an operand. If there are multiple divisor
-// assumptions, the gcd of all of them is returned. If there are no
-// divisor assumptions, std::nullopt is returned.
 std::optional<uint64_t>
 AssumeIntOp::getUnionedUnsignedDivisor(unsigned operandIndex) {
   auto assumptions = getOperandAssumptions(operandIndex);
   std::optional<uint64_t> divisorUnion;
   for (auto assumption : assumptions) {
     auto divisor = assumption.getUdiv();
-    if (divisor) {
-      if (divisorUnion)
-        divisorUnion = std::gcd(*divisor, *divisorUnion);
-      else
-        divisorUnion = *divisor;
-    }
+    if (!divisor)
+      return std::nullopt;
+    if (divisorUnion)
+      divisorUnion = std::gcd(*divisor, *divisorUnion);
+    else
+      divisorUnion = *divisor;
   }
   return divisorUnion;
 }
diff --git a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
index 648e4f4dad58..aaa10da27005 100644
--- a/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
+++ b/compiler/src/iree/compiler/Dialect/Util/IR/UtilOps.td
@@ -336,7 +336,9 @@ def OpGroupAddressOffsetArithmeticOps : OpDocGroup {
 let opDocGroup = OpGroupAddressOffsetArithmeticOps in {
 
 def Util_AlignOp : Util_PureOp<"align", [
-  SameOperandsAndResultType
+  SameOperandsAndResultType,
+  DeclareOpInterfaceMethods<InferIntRangeInterface>,
+  DeclareOpInterfaceMethods<Util_InferIntDivisibilityOpInterface>
 ]> {
   let summary = "Aligns up to a power-of-two alignment if required";
   let description = [{
@@ -504,14 +506,15 @@ def Util_AssumeIntOp : Util_PureOp<"assume.int", [
 
     // Gets the unioned unsigned range for an operand. If there are multiple
     // assumptions for the operand, this will return the bounding range for
-    // them all. If there is no umin/umax, then std::nullopt will be returned
-    // for that position.
+    // them all. If any row in the set is missing a umin/umax, then
+    // std::nullopt will be returned for that position.
     std::pair<std::optional<uint64_t>, std::optional<uint64_t>>
     getUnionedUnsignedRange(unsigned operandIndex);
 
     // Gets the unioned divisor for an operand. If there are multiple divisor
     // assumptions, the gcd of all of them is returned. If there are no
-    // divisor assumptions, std::nullopt is returned.
+    // divisor assumptions, or if any row is missing a udiv, std::nullopt
+    // is returned.
     std::optional<uint64_t> getUnionedUnsignedDivisor(unsigned operandIndex);
   }];
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
index a98135d57bdb..da3c44f08cb1 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/BUILD.bazel
@@ -54,6 +54,8 @@ iree_compiler_cc_library(
         "//compiler/src/iree/compiler/Utils",
         "@llvm-project//llvm:Support",
         "@llvm-project//mlir:AffineDialect",
+        "@llvm-project//mlir:AffineTransforms",
+        "@llvm-project//mlir:AffineUtils",
         "@llvm-project//mlir:Analysis",
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:ArithTransforms",
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
index d233f11e0278..a8542c40ae64 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/CMakeLists.txt
@@ -43,6 +43,8 @@ iree_cc_library(
     ::PassesIncGen
     LLVMSupport
     MLIRAffineDialect
+    MLIRAffineTransforms
+    MLIRAffineUtils
     MLIRAnalysis
     MLIRArithDialect
     MLIRArithTransforms
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
index 79e61a174f4c..022beaac439b 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/OptimizeIntArithmetic.cpp
@@ -12,6 +12,9 @@
 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
 #include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Transforms/Transforms.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/IR/Matchers.h"
@@ -108,19 +111,35 @@ struct ConvertOpToUnsigned : public OpRewritePattern {
//===----------------------------------------------------------------------===// +// Matches IR like: +// %5 = arith.addi %0, %1 : int64 +// %6 = arith.index_castui %5 : int64 to index +// +// And moves the index_castui to the producer's operands: +// %3 = arith.index_castui %0 : int64 to index +// %4 = arith.index_castui %1 : int64 to index +// %5 = arith.addi %3, %4 : index +// struct ConvertUnsignedI64IndexCastProducerToIndex : public OpRewritePattern { ConvertUnsignedI64IndexCastProducerToIndex(MLIRContext *context, DataFlowSolver &solver) : OpRewritePattern(context), solver(solver) {} - LogicalResult matchAndRewrite(arith::IndexCastUIOp op, + LogicalResult matchAndRewrite(arith::IndexCastUIOp origIndexOp, PatternRewriter &rewriter) const override { - Type inType = op.getIn().getType(); - Type outType = op.getOut().getType(); + Type inType = origIndexOp.getIn().getType(); + Type outType = origIndexOp.getOut().getType(); if (!inType.isSignlessInteger(64) && isa(outType)) return failure(); + Operation *producer = origIndexOp.getIn().getDefiningOp(); + if (!producer) + return failure(); + auto producerResult = producer->getResult(0); + if (!producerResult.hasOneUse()) + return failure(); + auto pred = [&](Value v) -> bool { auto *result = solver.lookupState(v); if (!result || result->getValue().isUninitialized()) { @@ -137,7 +156,6 @@ struct ConvertUnsignedI64IndexCastProducerToIndex llvm::all_of(op->getResults(), pred); }; - Operation *producer = op.getIn().getDefiningOp(); if (!isa_and_present(producer)) @@ -145,6 +163,7 @@ struct ConvertUnsignedI64IndexCastProducerToIndex if (!isOpStaticallyLegal(producer)) return failure(); + // Make modifications. rewriter.modifyOpInPlace(producer, [&]() { rewriter.setInsertionPoint(producer); for (auto &operand : producer->getOpOperands()) { @@ -156,6 +175,8 @@ struct ConvertUnsignedI64IndexCastProducerToIndex } producer->getResult(0).setType(outType); }); + origIndexOp.getOut().replaceAllUsesWith(producer->getResult(0)); + rewriter.eraseOp(origIndexOp); return success(); } @@ -206,6 +227,52 @@ struct RemUIDivisibilityByConstant : public OpRewritePattern { DataFlowSolver &solver; }; +//===----------------------------------------------------------------------===// +// Affine expansion +// affine.apply expansion can fail after producing a lot of IR. Since this is +// a bad thing to be doing as part of our overall iteration, we do it as a +// preprocessing walk. This also lets it be well behaved with respect to +// error messaging, etc. We will likely replace this with a more integrated +// version at some point which can use the bounds analysis to avoid corners +// of the original. 
+//===----------------------------------------------------------------------===//
+
+void expandAffineOps(Operation *rootOp) {
+  IRRewriter rewriter(rootOp->getContext());
+  rootOp->walk([&](affine::AffineApplyOp op) {
+    LLVM_DEBUG(dbgs() << "** Expand affine.apply: " << op << "\n");
+    rewriter.setInsertionPoint(op);
+    auto maybeExpanded =
+        mlir::affine::expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
+                                      llvm::to_vector<4>(op.getOperands()));
+    if (!maybeExpanded) {
+      LLVM_DEBUG(dbgs() << "** ERROR: Failed to expand affine.apply\n");
+      return;
+    }
+    rewriter.replaceOp(op, *maybeExpanded);
+  });
+}
+
+//===----------------------------------------------------------------------===//
+// General optimization patterns
+//===----------------------------------------------------------------------===//
+
+struct ElideTruncOfIndexCast : public OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    Operation *producer = truncOp.getOperand().getDefiningOp();
+    if (!producer)
+      return failure();
+    if (!isa<arith::IndexCastOp, arith::IndexCastUIOp>(producer))
+      return failure();
+    rewriter.replaceOpWithNewOp<arith::IndexCastUIOp>(
+        truncOp, truncOp.getResult().getType(), producer->getOperand(0));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass setup
 //===----------------------------------------------------------------------===//
@@ -270,6 +337,9 @@ class OptimizeIntArithmeticPass
   void runOnOperation() override {
     Operation *op = getOperation();
     MLIRContext *ctx = op->getContext();
+
+    expandAffineOps(op);
+
     DataFlowSolver solver;
     solver.load<dataflow::DeadCodeAnalysis>();
     solver.load<dataflow::IntegerRangeAnalysis>();
@@ -281,13 +351,15 @@ class OptimizeIntArithmeticPass
     arith::populateIntRangeOptimizationsPatterns(patterns, solver);
 
     // Populate canonicalization patterns.
-    auto arithDialectTypeID =
-        ctx->getOrLoadDialect<arith::ArithDialect>()->getTypeID();
+    auto arithDialect = ctx->getOrLoadDialect<arith::ArithDialect>();
     for (const RegisteredOperationName &name : ctx->getRegisteredOperations()) {
-      if (name.getDialect().getTypeID() == arithDialectTypeID)
+      if (&name.getDialect() == arithDialect)
         name.getCanonicalizationPatterns(patterns, ctx);
     }
 
+    // General optimization patterns.
+    patterns.add<ElideTruncOfIndexCast>(ctx);
+
     // Populate unsigned conversion patterns.
     patterns.add<ConvertOpToUnsigned<arith::CeilDivSIOp, arith::CeilDivUIOp>,
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
index e8c2740bb31f..3e2235b7f8f4 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/integer_divisibility.mlir
@@ -34,3 +34,14 @@ util.func @remui_div_by_unrelated(%arg0 : index) -> index {
   %1 = arith.remui %0, %cst : index
   util.return %1 : index
 }
+
+// -----
+// A missing udiv in a multi-row assumption is treated as an unknown.
+// CHECK-LABEL: @missing_udiv_skipped
+util.func @missing_udiv_skipped(%arg0 : index) -> index {
+  // CHECK: arith.remui
+  %cst = arith.constant 16 : index
+  %0 = util.assume.int %arg0[<udiv = 16>, <>] : index
+  %1 = arith.remui %0, %cst : index
+  util.return %1 : index
+}
diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
index 4726941c42db..41b304a89c1f 100644
--- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
+++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/optimize_int_arithmetic.mlir
@@ -27,6 +27,30 @@ util.func @index_lower_bound(%arg0 : index) -> i1 {
   util.return %1 : i1
 }
 
+// -----
+// If there is a missing umax in a multi-row assumption, then it must
+// be treated as having no known upper bound.
+// CHECK-LABEL: @missing_umax_skipped
+util.func @missing_umax_skipped(%arg0 : index) -> i1 {
+  // CHECK: arith.cmpi
+  %cst = arith.constant 101 : index
+  %0 = util.assume.int %arg0[<umin = 0, umax = 100>, <umin = 0>] : index
+  %1 = arith.cmpi ult, %0, %cst : index
+  util.return %1 : i1
+}
+
+// -----
+// If there is a missing umin in a multi-row assumption, then it must
+// be treated as having no known lower bound.
+// CHECK-LABEL: @missing_umin_skipped
+util.func @missing_umin_skipped(%arg0 : index) -> i1 {
+  // CHECK: arith.cmpi
+  %cst = arith.constant 5 : index
+  %0 = util.assume.int %arg0[<umin = 6, umax = 100>, <umax = 100>] : index
+  %1 = arith.cmpi ugt, %0, %cst : index
+  util.return %1 : i1
+}
+
 // -----
 // CHECK-LABEL: @index_indeterminate
 util.func @index_indeterminate(%arg0 : index) -> i1 {
@@ -246,6 +270,20 @@ util.func @index_cast_i64_to_index_addi(%arg0 : index, %arg1 : index) -> index {
   util.return %3 : index
 }
 
+// -----
+// A producer with multiple uses should not be converted.
+// CHECK-LABEL: @index_cast_i64_to_index_addi_multiuse
+util.func @index_cast_i64_to_index_addi_multiuse(%arg0 : index, %arg1 : index) -> index, i64 {
+  // CHECK: %[[ASSUME:.*]] = util.assume.int
+  %0 = util.assume.int %arg0<umin = 0, umax = 100> : index
+  // CHECK: arith.index_cast
+  // CHECK: arith.index_cast
+  %1 = arith.index_cast %0 : index to i64
+  %2 = arith.addi %1, %1 : i64
+  %3 = arith.index_cast %2 : i64 to index
+  util.return %3, %2 : index, i64
+}
+
 // -----
 // CHECK-LABEL: @index_cast_i64_to_index_ceildivsi
 util.func @index_cast_i64_to_index_ceildivsi(%arg0 : index, %arg1 : index) -> index {
@@ -371,3 +409,56 @@ util.func @index_cast_i64_to_index_remsi(%arg0 : index, %arg1 : index) -> index {
   %3 = arith.index_cast %2 : i64 to index
   util.return %3 : index
 }
+
+// -----
+// A truncate of an index cast can be folded into a single narrower cast.
+// CHECK-LABEL: @elide_trunc_of_index_castui
+util.func @elide_trunc_of_index_castui(%arg0 : index) -> i32 {
+  %1 = arith.index_castui %arg0 : index to i64
+  %2 = arith.trunci %1 : i64 to i32
+  // CHECK: %[[RESULT:.*]] = arith.index_castui %arg0 : index to i32
+  // CHECK: util.return %[[RESULT]]
+  util.return %2 : i32
+}
+
+// -----
+// CHECK-LABEL: @elide_trunc_of_index_cast
+util.func @elide_trunc_of_index_cast(%arg0 : index) -> i32 {
+  %1 = arith.index_cast %arg0 : index to i64
+  %2 = arith.trunci %1 : i64 to i32
+  // CHECK: %[[RESULT:.*]] = arith.index_castui %arg0 : index to i32
+  // CHECK: util.return %[[RESULT]]
+  util.return %2 : i32
+}
+
+// -----
+// CHECK-LABEL: @util_align_bounds_div
+util.func @util_align_bounds_div(%arg0 : index, %arg1 : index) -> index, index, index, i1, i1 {
+  %0 = util.assume.int %arg0<umin = 1, umax = 128> : index
+  %1 = util.assume.int %arg1<umin = 64, umax = 64> : index
+  // CHECK-DAG: %[[ZERO:.*]] = arith.constant 0
+  // CHECK-DAG: %[[FALSE:.*]] = arith.constant false
+  // CHECK-DAG: %[[TRUE:.*]] = arith.constant true
+  // CHECK-DAG: %[[C64:.*]] = arith.constant 64
+  // CHECK-DAG: %[[ASSUME:.*]] = util.assume.int %arg0
+  // CHECK: %[[ALIGN:.*]] = util.align %[[ASSUME]], %[[C64]]
+  %2 = util.align %0, %1 : index
+
+  // The result should be >= 64 and <= 128.
+  %c64 = arith.constant 64 : index
+  %c128 = arith.constant 128 : index
+  %lower = arith.cmpi uge, %2, %c64 : index   // True
+  %upper = arith.cmpi ule, %2, %c128 : index  // True
+  %under = arith.cmpi ult, %2, %c64 : index   // False
+  %over = arith.cmpi ugt, %2, %c128 : index   // False
+  %in_bounds = arith.andi %lower, %upper : i1 // True
+  %out_bounds = arith.andi %under, %over : i1 // False
+
+  // And 64 should evenly divide it.
+  %rem64 = arith.remui %2, %c64 : index
+  // But 128 should not.
+  // CHECK: %[[REM128:.*]] = arith.remui
+  %rem128 = arith.remui %2, %c128 : index
+  // CHECK: util.return %[[ALIGN]], %[[ZERO]], %[[REM128]], %[[TRUE]], %[[FALSE]]
+  util.return %2, %rem64, %rem128, %in_bounds, %out_bounds : index, index, index, i1, i1
+}
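
As a standalone illustration of the align-up identity that AlignOp::inferResultRanges applies to the range endpoints, here is a minimal C++ sketch; alignUp is a hypothetical helper, plain uint64_t stands in for APInt, and the values mirror @util_align_bounds_div above:

    #include <cassert>
    #include <cstdint>

    // Align-up for a power-of-two alignment, as in util.align:
    //   (value + (alignment - 1)) & ~(alignment - 1)
    static uint64_t alignUp(uint64_t value, uint64_t alignment) {
      uint64_t m = alignment - 1;
      return (value + m) & ~m;
    }

    int main() {
      // A value in [1, 128] aligned to 64 lands in [64, 128]: aligning the
      // endpoints bounds the result, and the alignment always divides it.
      assert(alignUp(1, 64) == 64);
      assert(alignUp(128, 64) == 128);
      assert(alignUp(100, 64) == 128 && alignUp(100, 64) % 64 == 0);
      return 0;
    }

Because align-up is monotonic, applying it to umin and umax is enough to bound every value in between, which is what the ConstantIntRanges::fromUnsigned call in the patch relies on.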
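
The row-union semantics that getUnionedUnsignedRange and getUnionedUnsignedDivisor now implement (bounding range across all rows, gcd of divisors, unknown as soon as any row omits a field) can be sketched without the dialect; Row and unionUdiv here are hypothetical stand-ins, not IREE API:

    #include <cstdint>
    #include <numeric>
    #include <optional>
    #include <vector>

    // One util.assume.int row; nullopt models an omitted field, like <>.
    struct Row {
      std::optional<uint64_t> umin, umax, udiv;
    };

    // gcd of every row's udiv, but unknown if any row omits it, matching
    // the corrected getUnionedUnsignedDivisor.
    std::optional<uint64_t> unionUdiv(const std::vector<Row> &rows) {
      std::optional<uint64_t> result;
      for (const Row &row : rows) {
        if (!row.udiv)
          return std::nullopt;
        result = result ? std::gcd(*row.udiv, *result) : *row.udiv;
      }
      return result;
    }

For example, rows with udiv 16 and udiv 24 union to 8, while a row with udiv 16 next to an empty <> unions to nullopt, which is why @missing_udiv_skipped keeps its arith.remui.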
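
Finally, the 53-bit clamp in BindSymbolicShapes.cpp can be sanity-checked with the same expression the patch uses; this static_assert is a freestanding check, not part of the patch:

    #include <cstdint>

    static_assert((static_cast<uint64_t>(1) << 53) - 1 == 9007199254740991ULL,
                  "MAX_DIM_VALUE equals JavaScript's Number.MAX_SAFE_INTEGER");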