From 7483883d4c70676acd8332e4ad45d78d08e1d916 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Fri, 6 Sep 2024 12:02:14 +0800
Subject: [PATCH] pushdown aggregation's pre-projection ahead of expand node

---
 .../backendsapi/clickhouse/CHBackend.scala    |   9 ++
 .../backendsapi/clickhouse/CHRuleApi.scala    |   1 +
 .../PushdownExtraProjectionBeforeExpand.scala | 115 ++++++++++++++++++
 .../execution/GlutenClickHouseTPCHSuite.scala |  15 +++
 4 files changed, 140 insertions(+)
 create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownExtraProjectionBeforeExpand.scala

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 86a69f8422808..163f7568f7131 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -376,6 +376,15 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }
 
+  // Move the pre-projection for an aggregation ahead of the expand node,
+  // for example: select a, b, sum(c+d) from t group by a, b with cube
+  def enablePushdownPreProjectionAheadExpand(): Boolean = {
+    SparkEnv.get.conf.getBoolean(
+      "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
+      true
+    )
+  }
+
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index fb5147157d94c..550044d3798c8 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -76,6 +76,7 @@ private object CHRuleApi {
     injector.injectTransform(_ => EliminateLocalSort)
     injector.injectTransform(_ => CollapseProjectExecTransformer)
     injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session))
+    injector.injectTransform(c => PushdownExtraProjectionBeforeExpand.apply(c.session))
     injector.injectTransform(
       c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session))
     injector.injectTransform(c => InsertTransitions(c.outputsColumnar))
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownExtraProjectionBeforeExpand.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownExtraProjectionBeforeExpand.scala
new file mode 100644
index 0000000000000..5e462c711acb1
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/PushdownExtraProjectionBeforeExpand.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+import org.apache.gluten.execution._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.SparkPlan
+
+// If an aggregate function's arguments contain an expression that is not a plain
+// attribute, a pre-projection is introduced to evaluate the expression first, so
+// that all the arguments become attributes. For an aggregation with grouping sets,
+// this pre-projection is placed between the expand operator and the aggregation,
+// which is inefficient: the expression is re-evaluated for every grouping set.
+// This rule moves such a pre-projection ahead of the expand operator.
+case class PushdownExtraProjectionBeforeExpand(session: SparkSession)
+  extends Rule[SparkPlan]
+  with Logging {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (!CHBackendSettings.enablePushdownPreProjectionAheadExpand) {
+      return plan
+    }
+    plan.transformUp {
+      case hashAggregate: CHHashAggregateExecTransformer =>
+        tryPushdownAggregatePreProject(hashAggregate)
+    }
+  }
+
+  def isGroupingColumn(e: NamedExpression): Boolean = {
+    e.isInstanceOf[AttributeReference] && e
+      .asInstanceOf[AttributeReference]
+      .name
+      .startsWith(VirtualColumn.groupingIdName, 0)
+  }
+
+  def dropGroupingColumn(expressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+    expressions.filter(!isGroupingColumn(_))
+  }
+
+  def isAttributeOrLiteral(e: Expression): Boolean = {
+    e match {
+      case _: Attribute | _: BoundReference | _: Literal => true
+      case _ => false
+    }
+  }
+
+  def tryPushdownAggregatePreProject(plan: SparkPlan): SparkPlan = {
+    val hashAggregate = plan.asInstanceOf[CHHashAggregateExecTransformer]
+    val originGroupingKeys = dropGroupingColumn(hashAggregate.groupingExpressions)
+    // To keep things simple, if any grouping key is not an attribute, don't change anything.
+    if (!originGroupingKeys.forall(isAttributeOrLiteral(_))) {
+      return hashAggregate
+    }
+    hashAggregate.child match {
+      case project @ ProjectExecTransformer(_, expand: ExpandExecTransformer) =>
+        val rootChild = expand.child
+
+        // This should not happen.
+        if (rootChild.output.exists(isGroupingColumn(_))) {
+          return hashAggregate
+        }
+        // Drop the grouping id column.
+        val aheadProjectExprs = dropGroupingColumn(project.projectList)
+        val originInputAttributes = aheadProjectExprs.filter(e => isAttributeOrLiteral(e))
+
+        val preProjectExprs = aheadProjectExprs.filter(e => !isAttributeOrLiteral(e))
+        if (preProjectExprs.isEmpty) {
+          return hashAggregate
+        }
+
+        // If any expression references a grouping key, don't change anything.
+        // This should not happen.
+        if (preProjectExprs.exists(e => originGroupingKeys.exists(e.references.contains(_)))) {
+          return hashAggregate
+        }
+
+        // The new ahead-of-expand project node takes rootChild's output plus
+        // preProjectExprs as its projection expressions.
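+        // For example, given `select a, b, sum(c+d) from t group by a, b with cube`,
+        // the original plan evaluates c+d between the expand and the aggregation,
+        // i.e. once per grouping set per input row. After this rewrite, c+d is
+        // evaluated once per input row below the expand, and the expand merely
+        // replicates the pre-computed column.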
+        val aheadProject = ProjectExecTransformer(rootChild.output ++ preProjectExprs, rootChild)
+        val aheadProjectOutput = aheadProject.output
+
+        val preProjectOutputAttrs = aheadProjectOutput.filter(
+          e => !originInputAttributes.exists(_.exprId.equals(e.exprId)))
+
+        val newExpandProjections = expand.projections.map {
+          exprs => exprs ++ preProjectOutputAttrs
+        }
+        val newExpandOutput = expand.output ++ preProjectOutputAttrs
+        val newExpand = ExpandExecTransformer(newExpandProjections, newExpandOutput, aheadProject)
+
+        hashAggregate.withNewChildren(Seq(newExpand))
+      case _ => plan
+    }
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index c517afcb29056..dbaab25939ab0 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -547,5 +547,20 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
     spark.sql("drop table cross_join_t")
   }
+
+  test("Pushdown aggregation pre-projection ahead of expand") {
+    spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet")
+    spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)")
+    var sql = """
+                | select a, b, sum(d+c) from t1 group by a, b with cube
+                | order by a, b
+                |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    sql = """
+          | select a, b, sum(a+c), sum(b+d) from t1 group by a, b with cube
+          | order by a, b
+          |""".stripMargin
+    compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    spark.sql("drop table t1")
+  }
 }
 // scalastyle:off line.size.limit
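
Reviewer note: the sketch below is a minimal way to observe the rewrite from a
Scala session. It assumes a SparkSession already configured for the Gluten
ClickHouse backend (plugin and off-heap settings omitted); the session setup,
master/appName values, and table name t1 are illustrative, while the config key
and the query come from the patch above.

    import org.apache.spark.sql.SparkSession

    // Demo session; assumes Gluten's ClickHouse backend is on the classpath.
    val spark = SparkSession
      .builder()
      .master("local[2]")
      .appName("preprojection-pushdown-demo")
      // The rule is enabled by default; flip this to false to compare plan shapes.
      .config(
        "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand",
        "true")
      .getOrCreate()

    spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet")
    spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)")

    val df = spark.sql(
      "select a, b, sum(c+d) from t1 group by a, b with cube order by a, b")
    // With the rule enabled, the explained plan should show the projection that
    // computes (c + d) below the expand node rather than between the expand node
    // and the aggregation.
    df.explain()
    df.show()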