Skip to content

Commit

Permalink
[CORE] Fix wrong fallback cost (apache#3967)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you authored Dec 7, 2023
1 parent a935707 commit 1b1d0d3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,15 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
}.size == 2,
df.queryExecution.executedPlan)
}

runQueryAndCompare("select c1, count(*) from tmp1 group by c1") {
df =>
assert(
collect(df.queryExecution.executedPlan) {
case h: HashAggregateExecTransformer => h
}.size == 2,
df.queryExecution.executedPlan)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -142,21 +142,21 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
*/
def countStageFallbackCostInternal(plan: SparkPlan): Unit = {
plan match {
case _: GlutenPlan if plan.children.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined =>
plan.children
case glutenPlan: GlutenPlan =>
val leaves = glutenPlan.collectLeaves()
leaves
.filter(_.isInstanceOf[InMemoryTableScanExec])
.foreach {
// For this case, table cache will internally execute ColumnarToRow if
// we make the stage fall back.
case child if InMemoryTableScanHelper.isGlutenTableCache(child) =>
case tableCache if InMemoryTableScanHelper.isGlutenTableCache(tableCache) =>
stageFallbackCost = stageFallbackCost + 1
// For other case, table cache will save internal RowToColumnar if we make
// the stage fall back.
case _ =>
stageFallbackCost = stageFallbackCost - 1
}
case _: GlutenPlan if plan.children.find(_.isInstanceOf[QueryStageExec]).isDefined =>
plan.children
leaves
.filter(_.isInstanceOf[QueryStageExec])
.foreach {
case stage: QueryStageExec
Expand Down Expand Up @@ -212,15 +212,17 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
return None
}

val netFallbackNum = if (isAdaptiveContext) {
countFallback(plan) - countStageFallbackCost(plan)
val fallbackNum = countFallback(plan)
val fallbackCost = if (isAdaptiveContext) {
countStageFallbackCost(plan)
} else {
countFallback(plan)
0
}
val netFallbackNum = fallbackNum - fallbackCost
if (netFallbackNum >= fallbackThreshold) {
Some(
s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " +
s"threshold: $fallbackThreshold")
s"fallback num: $fallbackNum, cost: $fallbackCost, threshold: $fallbackThreshold")
} else {
None
}
Expand Down

0 comments on commit 1b1d0d3

Please sign in to comment.