Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Util] Fix an assert getting reached for certain nested loops in HoistIntoGlobals #19576

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -203,7 +203,7 @@ class ConstExprHoistingPolicy {
}
void enableHoist() {
assert(outcome == UNDECIDED &&
"can only disable hoisting of an undecided decision");
"can only enable hoisting of an undecided decision");
outcome = ENABLE_HOIST;
}

Original file line number Diff line number Diff line change
@@ -22,8 +22,6 @@ static void populateEscapingProducers(Operation *parentOp,
ConstExprOpInfo &info) {
SmallPtrSet<Operation *, 8> containedOps;
parentOp->walk<WalkOrder::PreOrder>([&](Operation *itOp) {
containedOps.insert(parentOp);

// For the outer-most op, consider that all operands escape.
if (itOp == parentOp) {
info.producers.insert(itOp->getOperands().begin(),
@@ -33,8 +31,9 @@ static void populateEscapingProducers(Operation *parentOp,
: WalkResult::advance();
}

// For nested operations, only consider that they escape if they are
// defined outside of the parent.
containedOps.insert(itOp->getParentOp());
// A nested operation escapes if every operand is defined outside contained
// ops.
for (Value operand : itOp->getOperands()) {
Block *block = operand.getParentBlock();
if (!containedOps.contains(block->getParentOp())) {
Original file line number Diff line number Diff line change
@@ -103,17 +103,13 @@ class HoistIntoGlobalsPass
continue;
auto walkRes = funcOp.walk<WalkOrder::PreOrder>([&](Operation *iterOp) {
// We only want to look at const-expr ops (non roots) since they may
// have interesting escapes. Early exit here for efficiency.
// have interesting escapes. Early exit here if the op has no
// ConstValueInfo or its first result cannot be hoisted.
auto *iterInfo = constExprs.lookup(iterOp);
if (!iterInfo)
if (!iterInfo || policy.getDecision(iterInfo)->getOutcome() !=
ConstExprHoistingPolicy::ENABLE_HOIST)
return WalkResult::advance();
for (Value constExprResult : iterOp->getResults()) {
auto *resultInfo = constExprs.lookup(constExprResult);
assert(resultInfo && "must have const-expr info");
if (policy.getDecision(resultInfo)->getOutcome() !=
ConstExprHoistingPolicy::ENABLE_HOIST) {
continue;
}
if (failed(hoistConstExpr(constExprResult, hoistedMap, moduleSymbols,
constExprs))) {
return WalkResult::interrupt();
Original file line number Diff line number Diff line change
@@ -364,3 +364,30 @@ module @nested_program_const_expr {
}
}
}

// -----

// Prior to this patch, a bug caused %3#0 to be considered an escaping producer for %1.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use issue numbers when referencing things - "this patch" won't mean much after this lands :)

the description here is a bit too much for a test - it encodes assumptions/situations that are very specific to the current implementation of the code and will get out of date fast - it's useful to include comments as to what a test is verifying but it doesn't need the whole back story and prior behavior, just what's expected - if the test ever starts failing as someone is changing code they care about what situation they are trying to make work, not how it did/didn't work previously

we generally want to avoid "it doesn't crash" tests as it doesn't help anyone coming along working on the code - "not crashing" is a weak test that doesn't prove the behavior does anything but not crash - those are better for large bulk test corpuses or indirectly via e2e tests - if adding a test and fixing code then a test should be added for the behavior being modified/fixed/etc. here, for example, whatever ops or attributes caused the crash need to be CHECKed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, Ben. I'll create an issue to link with and look for a different way to test these changes. I found it difficult to add a test since I'm essentially only changing the analysis phase of this pass: the actual behavior on IR should not be affected.

Testing a batched aten.multinomial op e2e would be sufficient to cover this, but that might also need some changes from #19563 and #19556 to pass.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revisiting this after a little while. I'm a bit stuck on figuring out how to test these changes. The bug I encountered was incredibly specific. The simplest reproducer of which not only requires an op with multiple results, but one which is multiply-nested within other ops.

Does it makes sense to add a new e2e compile test to cover this highly specific case? This seems expensive considering the locality of the changes.

There are existing e2e tests in external testing suites (e.g. the test migraphx_agentmodel__AgentModel in the e2eshark ONNX model test suite) that are affected by this change. Is it acceptable to not add any tests in this PR, but rather put a link in the commit description to an issue (e.g. nod-ai/SHARK-ModelDev#876 or a fresh issue in IREE) and indicate that the results of this change are already being tested elsewhere?

// This means that %3#0 got assigned a ConstValueInfo when expanding, but %3#1 did not (it is unused).
// HoistIntoGlobalsPass includes an assertion that, for each *op* succeeding ConstValueInfo lookup,
// every result of the op also has ConstValueInfo. This assert caused the compiler to abort when checking %3#1.
// This lit test simply verifies that the pass does not crash.

// CHECK-LABEL: @nested_bodies_unused_result_no_crash
module @nested_bodies_unused_result_no_crash {
util.func public @main() -> tensor<i32> {
%c0_i32 = arith.constant 0 : i32
%c1_i32 = arith.constant 1 : i32
%c10_i32 = arith.constant 10 : i32
%0 = tensor.empty() : tensor<i32>
%1 = scf.for %arg0 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg1 = %0) -> (tensor<i32>) : i32 {
%2 = scf.for %arg2 = %c0_i32 to %c10_i32 step %c1_i32 iter_args(%arg3 = %arg1) -> (tensor<i32>) : i32 {
%3:2 = "iree_unregistered.const_expr"(%arg0, %arg2) : (i32, i32) -> (i32, i32)
%inserted = tensor.insert %3#0 into %arg3[] : tensor<i32>
scf.yield %inserted : tensor<i32>
}
scf.yield %2 : tensor<i32>
}
util.return %1 : tensor<i32>
}
}