Commit 4af372f: Merge branch 'main' into gluten_7100
taiyang-li authored Sep 10, 2024
2 parents fec0eb4 + 597e1aa
Showing 8 changed files with 150 additions and 2 deletions.
@@ -376,6 +376,15 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
    )
  }

  // Move the pre-projection for an aggregation ahead of the expand node,
  // e.g. for: select a, b, sum(c+d) from t group by a, b with cube
  def enablePushdownPreProjectionAheadExpand(): Boolean = {
    SparkEnv.get.conf.getBoolean(
      "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
      true
    )
  }

  override def enableNativeWriteFiles(): Boolean = {
    GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
  }
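
For context, a minimal sketch of toggling this flag, assuming a local session with the Gluten ClickHouse backend on the classpath (the demo object, view name, and query are illustrative, not part of this commit):

import org.apache.spark.sql.SparkSession

object PreProjectionAheadExpandDemo {
  def main(args: Array[String]): Unit = {
    // The flag is read from SparkEnv.get.conf, so set it before the session
    // (and its SparkEnv) is created; it defaults to true.
    val spark = SparkSession
      .builder()
      .master("local[1]")
      .config(
        "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
        "true")
      .getOrCreate()
    // A cube aggregation whose aggregate argument is an expression (c + d):
    // the plan shape this rule targets.
    spark
      .range(10)
      .selectExpr("id as a", "id % 2 as b", "id as c", "id as d")
      .createOrReplaceTempView("t")
    spark.sql("select a, b, sum(c + d) from t group by a, b with cube").explain()
    spark.stop()
  }
}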
@@ -76,6 +76,7 @@ private object CHRuleApi {
    injector.injectTransform(_ => EliminateLocalSort)
    injector.injectTransform(_ => CollapseProjectExecTransformer)
    injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session))
    injector.injectTransform(c => PushdownAggregatePreProjectionAheadExpand.apply(c.session))
    injector.injectTransform(
      c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
    injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
@@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// If an aggregate function's parameters contain an expression (not a simple
// attribute), a pre-projection is introduced to evaluate the expressions
// first, so that all the parameters become attributes.
// If the aggregation has grouping sets, this pre-projection is placed after
// the expand operator, which is inefficient. This rule moves the
// pre-projection ahead of the expand operator.
case class PushdownAggregatePreProjectionAheadExpand(session: SparkSession)
  extends Rule[SparkPlan]
  with Logging {
  override def apply(plan: SparkPlan): SparkPlan = {
    if (!CHBackendSettings.enablePushdownPreProjectionAheadExpand) {
      return plan
    }
    plan.transformUp {
      case hashAggregate: CHHashAggregateExecTransformer =>
        tryPushdownAggregatePreProject(hashAggregate)
    }
  }

  def isGroupingColumn(e: NamedExpression): Boolean = e match {
    case a: AttributeReference => a.name.startsWith(VirtualColumn.groupingIdName)
    case _ => false
  }

  def dropGroupingColumn(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
    expressions.filter(!isGroupingColumn(_))
  }

  def isAttributeOrLiteral(e: Expression): Boolean = {
    e match {
      case _: Attribute | _: BoundReference | _: Literal => true
      case _ => false
    }
  }

  def tryPushdownAggregatePreProject(plan: SparkPlan): SparkPlan = {
    val hashAggregate = plan.asInstanceOf[CHHashAggregateExecTransformer]
    val originGroupingKeys = dropGroupingColumn(hashAggregate.groupingExpressions)
    // To keep things simple, if any grouping key is not an attribute, don't change anything.
    if (!originGroupingKeys.forall(isAttributeOrLiteral(_))) {
      return hashAggregate
    }
    hashAggregate.child match {
      case project @ ProjectExecTransformer(expressions, expand: ExpandExecTransformer) =>
        val rootChild = expand.child

        // This should not happen.
        if (rootChild.output.exists(isGroupingColumn(_))) {
          return hashAggregate
        }
        // Drop the grouping id column.
        val aheadProjectExprs = dropGroupingColumn(project.projectList)
        val originInputAttributes = aheadProjectExprs.filter(e => isAttributeOrLiteral(e))

        val preProjectExprs = aheadProjectExprs.filter(e => !isAttributeOrLiteral(e))
        if (preProjectExprs.isEmpty) {
          return hashAggregate
        }

        // If any expression involves grouping keys, don't change anything.
        // This should not happen.
        if (preProjectExprs.exists(e => originGroupingKeys.exists(e.references.contains(_)))) {
          return hashAggregate
        }

        // The new ahead project node takes rootChild's output plus preProjectExprs
        // as its projection expressions.
        val aheadProject = ProjectExecTransformer(rootChild.output ++ preProjectExprs, rootChild)
        val aheadProjectOutput = aheadProject.output

        val preProjectOutputAttrs = aheadProjectOutput.filter(
          e =>
            !originInputAttributes.exists(_.exprId.equals(e.asInstanceOf[NamedExpression].exprId)))

        val newExpandProjections = expand.projections.map {
          exprs => exprs ++ preProjectOutputAttrs
        }
        val newExpandOutput = expand.output ++ preProjectOutputAttrs
        val newExpand = ExpandExecTransformer(newExpandProjections, newExpandOutput, aheadProject)

        hashAggregate.withNewChildren(Seq(newExpand))
      case _ => plan
    }
  }
}
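
Schematically, the rewrite performed above for a query like "select a, b, sum(c + d) from t group by a, b with cube" (a sketch: attribute ids and the per-grouping-set projections are elided):

before:  CHHashAggregateExecTransformer
           +- ProjectExecTransformer [a, b, spark_grouping_id, (c + d)]
              +- ExpandExecTransformer [grouping-set projections of a, b]
                 +- <child>

after:   CHHashAggregateExecTransformer
           +- ExpandExecTransformer [original projections ++ (c + d) attribute]
              +- ProjectExecTransformer [<child output> ++ (c + d)]
                 +- <child>

The expression (c + d) is now evaluated once per input row before Expand, instead of once per grouping-set replica of each row after it.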
@@ -197,7 +197,6 @@ object CHExpressionUtil {
    ENCODE -> EncodeDecodeValidator(),
    ARRAY_REPEAT -> DefaultValidator(),
    ARRAY_REMOVE -> DefaultValidator(),
    ARRAYS_ZIP -> DefaultValidator(),
    DATE_FROM_UNIX_DATE -> DefaultValidator(),
    MONOTONICALLY_INCREASING_ID -> DefaultValidator(),
    SPARK_PARTITION_ID -> DefaultValidator(),
@@ -547,5 +547,21 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
    compareResultsAgainstVanillaSpark(sql, true, { _ => })
    spark.sql("drop table cross_join_t")
  }

  test("Pushdown aggregation pre-projection ahead expand") {
    spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet")
    spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)")
    var sql = """
               | select a, b, sum(d+c) from t1 group by a, b with cube
               | order by a, b
               |""".stripMargin
    compareResultsAgainstVanillaSpark(sql, true, { _ => })
    sql = """
           | select a, b, sum(a+c), sum(b+d) from t1 group by a, b with cube
           | order by a, b
           |""".stripMargin
    compareResultsAgainstVanillaSpark(sql, true, { _ => })
    spark.sql("drop table t1")
  }
}
// scalastyle:off line.size.limit
@@ -792,4 +792,12 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerSuite {
|""".stripMargin
runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
}

test("test function arrays_zip") {
val sql = """
|SELECT arrays_zip(array(id, id+1, id+2), array(id, id-1, id-2))
|FROM range(10)
|""".stripMargin
runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer])
}
}
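
For reference, the vanilla Spark semantics that the new mapping in the next file (arrays_zip parsed to ClickHouse's arrayZipUnaligned) has to reproduce: arrays of different lengths are zipped up to the longest one, with missing slots filled by null. A quick check in a Spark shell (illustrative, not part of this commit):

spark.sql("SELECT arrays_zip(array(1, 2), array('a', 'b', 'c'))").show(false)
// expected: [{1, a}, {2, b}, {null, c}] -- the shorter array is padded with nulls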
@@ -168,6 +168,7 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Shuffle, shuffle, arrayShuffle);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Range, range, range);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Flatten, flatten, sparkArrayFlatten);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArrayJoin, array_join, sparkArrayJoin);
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(ArraysZip, arrays_zip, arrayZipUnaligned);

// map functions
REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Map, map, map);
1 change: 0 additions & 1 deletion ep/build-clickhouse/src/package.sh
@@ -108,7 +108,6 @@ do
    if [[ "$replace_dot" == "32" ]]; then
        continue # error: xxx are the same file
    fi
    cp -f "${PACKAGE_DIR_PATH}"/jars/spark32/protobuf-java-"${protobuf_version}".jar "${PACKAGE_DIR_PATH}"/jars/spark"${replace_dot}"
    cp -f "${PACKAGE_DIR_PATH}"/jars/spark32/celeborn-client-spark-3-shaded_2.12-"${celeborn_version}".jar "${PACKAGE_DIR_PATH}"/jars/spark"${replace_dot}"
done
