Skip to content

Commit

Permalink
pushdown aggregation's pre-projection ahead expand node
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Sep 6, 2024
1 parent 37d09c1 commit 4a7ee54
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,15 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
)
}

// Move the pre-prejection for a aggregation ahead of the expand node
// for example, select a, b, sum(c+d) from t group by a, b with cube
def enablePushdownPreProjectionAheadExpand(): Boolean = {
SparkEnv.get.conf.getBoolean(
"spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
true
)
}

override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ private object CHRuleApi {
injector.injectTransform(_ => EliminateLocalSort)
injector.injectTransform(_ => CollapseProjectExecTransformer)
injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session))
injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.session))
injector.injectTransform(
c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// If there is an expression (not a attribute) in an aggregation function's
// parameters. It will introduce a pr-projection to calculate the expression
// at first, and make all the parameters be attributes.
// If it's a aggregation with grouping set, this pre-projection is pushed after
// expand operator. This is not efficent, we cannot move this pre-projection
// before the expand operator.
case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession)
extends Rule[SparkPlan]
with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
if (!CHBackendSettings.enablePushdownPreProjectionAheadExpand) {
return plan
}
plan.transformUp {
case hashAggregate: CHHashAggregateExecTransformer =>
tryPushdownAggregatePreProject(hashAggregate)
}
}

def isGroupingColumn(e: NamedExpression): Boolean = {
e.isInstanceOf[AttributeReference] && e
.asInstanceOf[AttributeReference]
.name
.startsWith(VirtualColumn.groupingIdName, 0)
}

def dropGroupingColumn(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
expressions.filter(!isGroupingColumn(_))
}

def isAttributeOrLiteral(e: Expression): Boolean = {
e match {
case _: Attribute | _: BoundReference | _: Literal => true
case _ => false
}
}

def tryPushdownAggregatePreProject(plan: SparkPlan): SparkPlan = {
val hashAggregate = plan.asInstanceOf[CHHashAggregateExecTransformer]
val originGroupingKeys = dropGroupingColumn(hashAggregate.groupingExpressions)
// make things simple, if any grouping key is not attribute, don't change anything
if (!originGroupingKeys.forall(isAttributeOrLiteral(_))) {
return hashAggregate
}
hashAggregate.child match {
case project @ ProjectExecTransformer(expressions, expand: ExpandExecTransformer) =>
val rootChild = expand.child

// This could not happen
if (rootChild.output.exists(isGroupingColumn(_))) {
return hashAggregate
}
// drop the goruping id column
val aheadProjectExprs = dropGroupingColumn(project.projectList)
val originInputAttributes = aheadProjectExprs.filter(e => isAttributeOrLiteral(e))

val preProjectExprs = aheadProjectExprs.filter(e => !isAttributeOrLiteral(e))
if (preProjectExprs.length == 0) {
return hashAggregate
}

// If the expression involves grouping keys, don't change anything
// This is should not happen.
if (preProjectExprs.exists(e => originGroupingKeys.exists(e.references.contains(_)))) {
return hashAggregate
}

// The new ahead project node will take rootChild's output and preProjectExprs as the
// the projection expressions.
val aheadProject = ProjectExecTransformer(rootChild.output ++ preProjectExprs, rootChild)
val aheadProjectOuput = aheadProject.output

val preProjectOutputAttrs = aheadProjectOuput.filter(
e =>
!originInputAttributes.exists(_.exprId.equals(e.asInstanceOf[NamedExpression].exprId)))

val newExpandProjections = expand.projections.map {
exprs => exprs ++ preProjectOutputAttrs
}
val newExpandOutput = expand.output ++ preProjectOutputAttrs
val newExpand = ExpandExecTransformer(newExpandProjections, newExpandOutput, aheadProject)

hashAggregate.withNewChildren(Seq(newExpand))
case _ => plan
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -547,5 +547,21 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(sql, true, { _ => })
spark.sql("drop table cross_join_t")
}

test("Pushdown aggregation pre-projection ahead expand") {
spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet")
spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)")
var sql = """
| select a, b , sum(d+c) from t1 group by a,b with cube
| order by a,b
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
sql = """
| select a, b , sum(a+c), sum(b+d) from t1 group by a,b with cube
| order by a,b
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })
spark.sql("drop table t1")
}
}
// scalastyle:off line.size.limit

0 comments on commit 4a7ee54

Please sign in to comment.