diff --git a/backends-velox/src/test/scala/io/glutenproject/execution/FallbackSuite.scala b/backends-velox/src/test/scala/io/glutenproject/execution/FallbackSuite.scala index 7e8049e33a30..d04eb9d5ffb6 100644 --- a/backends-velox/src/test/scala/io/glutenproject/execution/FallbackSuite.scala +++ b/backends-velox/src/test/scala/io/glutenproject/execution/FallbackSuite.scala @@ -133,6 +133,15 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl }.size == 2, df.queryExecution.executedPlan) } + + runQueryAndCompare("select c1, count(*) from tmp1 group by c1") { + df => + assert( + collect(df.queryExecution.executedPlan) { + case h: HashAggregateExecTransformer => h + }.size == 2, + df.queryExecution.executedPlan) + } } } } diff --git a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala index 98c0e9b4c77c..4e3f7ce58a15 100644 --- a/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala +++ b/gluten-core/src/main/scala/io/glutenproject/extension/ExpandFallbackPolicy.scala @@ -142,21 +142,21 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP */ def countStageFallbackCostInternal(plan: SparkPlan): Unit = { plan match { - case _: GlutenPlan if plan.children.find(_.isInstanceOf[InMemoryTableScanExec]).isDefined => - plan.children + case glutenPlan: GlutenPlan => + val leaves = glutenPlan.collectLeaves() + leaves .filter(_.isInstanceOf[InMemoryTableScanExec]) .foreach { // For this case, table cache will internally execute ColumnarToRow if // we make the stage fall back. - case child if InMemoryTableScanHelper.isGlutenTableCache(child) => + case tableCache if InMemoryTableScanHelper.isGlutenTableCache(tableCache) => stageFallbackCost = stageFallbackCost + 1 // For other case, table cache will save internal RowToColumnar if we make // the stage fall back. case _ => stageFallbackCost = stageFallbackCost - 1 } - case _: GlutenPlan if plan.children.find(_.isInstanceOf[QueryStageExec]).isDefined => - plan.children + leaves .filter(_.isInstanceOf[QueryStageExec]) .foreach { case stage: QueryStageExec @@ -212,15 +212,17 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP return None } - val netFallbackNum = if (isAdaptiveContext) { - countFallback(plan) - countStageFallbackCost(plan) + val fallbackNum = countFallback(plan) + val fallbackCost = if (isAdaptiveContext) { + countStageFallbackCost(plan) } else { - countFallback(plan) + 0 } + val netFallbackNum = fallbackNum - fallbackCost if (netFallbackNum >= fallbackThreshold) { Some( s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " + - s"threshold: $fallbackThreshold") + s"fallback num: $fallbackNum, cost: $fallbackCost, threshold: $fallbackThreshold") } else { None }