diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h index b914c599f967..978f9f1f52c1 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/ConstExpr.h @@ -203,7 +203,7 @@ class ConstExprHoistingPolicy { } void enableHoist() { assert(outcome == UNDECIDED && - "can only disable hoisting of an undecided decision"); + "can only enable hoisting of an undecided decision"); outcome = ENABLE_HOIST; } diff --git a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp index 57f4f29ea40d..207aabe108cb 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Analysis/Constant/OpOracle.cpp @@ -22,8 +22,6 @@ static void populateEscapingProducers(Operation *parentOp, ConstExprOpInfo &info) { SmallPtrSet containedOps; parentOp->walk([&](Operation *itOp) { - containedOps.insert(parentOp); - // For the outer-most op, consider that all operands escape. if (itOp == parentOp) { info.producers.insert(itOp->getOperands().begin(), @@ -33,8 +31,9 @@ static void populateEscapingProducers(Operation *parentOp, : WalkResult::advance(); } - // For nested operations, only consider that they escape if they are - // defined outside of the parent. + containedOps.insert(itOp->getParentOp()); + // An operand of a nested operation escapes if it is defined outside the + // contained ops. 
for (Value operand : itOp->getOperands()) { Block *block = operand.getParentBlock(); if (!containedOps.contains(block->getParentOp())) { diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp index af8ae92fa8b3..f625e1f13b5a 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/HoistIntoGlobals.cpp @@ -103,17 +103,13 @@ class HoistIntoGlobalsPass continue; auto walkRes = funcOp.walk([&](Operation *iterOp) { // We only want to look at const-expr ops (non roots) since they may - // have interesting escapes. Early exit here for efficiency. + // have interesting escapes. Early exit here if the op has no + // ConstValueInfo or its first result cannot be hoisted. auto *iterInfo = constExprs.lookup(iterOp); - if (!iterInfo) + if (!iterInfo || policy.getDecision(iterInfo)->getOutcome() != + ConstExprHoistingPolicy::ENABLE_HOIST) return WalkResult::advance(); for (Value constExprResult : iterOp->getResults()) { - auto *resultInfo = constExprs.lookup(constExprResult); - assert(resultInfo && "must have const-expr info"); - if (policy.getDecision(resultInfo)->getOutcome() != - ConstExprHoistingPolicy::ENABLE_HOIST) { - continue; - } if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols, constExprs))) { return WalkResult::interrupt(); diff --git a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir index cd2945552adf..ef6b78e8660f 100644 --- a/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir +++ b/compiler/src/iree/compiler/Dialect/Util/Transforms/test/hoist_into_globals.mlir @@ -364,3 +364,30 @@ module @nested_program_const_expr { } } } + +// ----- + +// Prior to this patch, a bug caused %3#0 to be considered an escaping producer for 
%1. +// As a result, %3#0 was assigned a ConstValueInfo during expansion, but %3#1 was not (it is unused). +// HoistIntoGlobalsPass used to assert that, for each *op* whose ConstValueInfo lookup succeeds, +// every result of the op also has a ConstValueInfo. That assertion aborted the compiler when checking %3#1. +// This lit test simply verifies that the pass does not crash. + +// CHECK-LABEL: @nested_bodies_unused_result_no_crash +module @nested_bodies_unused_result_no_crash { + util.func public @main() -> tensor<i32> { + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c10_i32 = arith.constant 10 : i32 + %0 = tensor.empty() : tensor<i32> + %1 = scf.for %arg0 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %0) -> (tensor<i32>) : i32 { + %2 = scf.for %arg2 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg3 = %arg1) -> (tensor<i32>) : i32 { + %3:2 = "iree_unregistered.const_expr"(%arg0, %arg2) : (i32, i32) -> (i32, i32) + %inserted = tensor.insert %3#0 into %arg3[] : tensor<i32> + scf.yield %inserted : tensor<i32> + } + scf.yield %2 : tensor<i32> + } + util.return %1 : tensor<i32> + } +}