diff --git a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala index f850b6f457ea0..c4b9a768fa844 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/extension/FlushableHashAggregateRule.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.extension -import org.apache.gluten.execution.{FlushableHashAggregateExecTransformer, HashAggregateExecTransformer, ProjectExecTransformer, RegularHashAggregateExecTransformer} +import org.apache.gluten.execution._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.aggregate.{Partial, PartialMerge} @@ -30,11 +30,12 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeLike * optimizations such as flushing and abandoning. */ case class FlushableHashAggregateRule(session: SparkSession) extends Rule[SparkPlan] { + import FlushableHashAggregateRule._ override def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case shuffle: ShuffleExchangeLike => + case ShuffleAndChild(shuffle, child) => // If an exchange follows a hash aggregate in which all functions are in partial mode, // then it's safe to convert the hash aggregate to flushable hash aggregate. - shuffle.child match { + child match { case HashAggPropagatedToShuffle(proj, agg) => shuffle.withNewChildren( Seq(proj.withNewChildren(Seq(FlushableHashAggregateExecTransformer( @@ -125,4 +126,16 @@ object FlushableHashAggregateRule { val distribution = ClusteredDistribution(agg.groupingExpressions) agg.child.outputPartitioning.satisfies(distribution) } + + private object ShuffleAndChild { + def unapply(plan: SparkPlan): Option[(SparkPlan, SparkPlan)] = plan match { + case s: ShuffleExchangeLike => + val child = s.child match { + case VeloxAppendBatchesExec(child, _) => child + case other => other + } + Some(s, child) + case other => None + } + } }