diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 86a69f8422808..163f7568f7131 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -376,6 +376,15 @@ object CHBackendSettings extends BackendSettingsApi with Logging { ) } + // Move the pre-prejection for a aggregation ahead of the expand node + // for example, select a, b, sum(c+d) from t group by a, b with cube + def enablePushdownPreProjectionAheadExpand(): Boolean = { + SparkEnv.get.conf.getBoolean( + "spark.gluten.sql.columnar.backend.ch.enable_pushdown_preprojection_ahead_expand", + true + ) + } + override def enableNativeWriteFiles(): Boolean = { GlutenConfig.getConf.enableNativeWriter.getOrElse(false) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index fb5147157d94c..550044d3798c8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -76,6 +76,7 @@ private object CHRuleApi { injector.injectTransform(_ => EliminateLocalSort) injector.injectTransform(_ => CollapseProjectExecTransformer) injector.injectTransform(c => RewriteSortMergeJoinToHashJoinRule.apply(c.session)) + injector.injectTransform(c => PushdownExtraProjectionBeforeExpand.apply(c.session)) injector.injectTransform( c => SparkPlanRules.extendedColumnarRule(c.conf.extendedColumnarTransformRules)(c.session)) injector.injectTransform(c => InsertTransitions(c.outputsColumnar)) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala index c517afcb29056..dbaab25939ab0 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala @@ -547,5 +547,20 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { compareResultsAgainstVanillaSpark(sql, true, { _ => }) spark.sql("drop table cross_join_t") } + + test("Pushdown aggregation pre-projection ahead expand") { + spark.sql("create table t1(a bigint, b bigint, c bigint, d bigint) using parquet") + spark.sql("insert into t1 values(1,2,3,4), (1,2,4,5), (1,3,4,5), (2,3,4,5)") + var sql = """ + | select a, b , sum(d+c) from t1 group by a,b with cube + | order by a,b + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + sql = """ + | select a, b , sum(a+c), sum(b+d) from t1 group by a,b with cube + | order by a,b + |""".stripMargin + spark.sql("drop table t1") + } } // scalastyle:off line.size.limit