diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index 98cfa0e7547b6..141778688967b 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -118,6 +118,7 @@ object CHRuleApi {
         SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)(
           c.session)))
     injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, CHBatch))
+    injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))
 
     // Gluten columnar: Fallback policies.
     injector.injectFallbackPolicy(
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index c2f91fa152148..de0680df10ab5 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -164,11 +164,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
         resultExpressions)
       CHHashAggregateExecTransformer(
         requiredChildDistributionExpressions,
-        groupingExpressions.distinct,
+        groupingExpressions,
         aggregateExpressions,
         aggregateAttributes,
         initialInputBufferOffset,
-        replacedResultExpressions.distinct,
+        replacedResultExpressions,
         child
       )
     }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala
new file mode 100644
index 0000000000000..c09deff0b03d7
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/RemoveDuplicatedColumns.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.execution._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.CHColumnarToRowExec
+
+/*
+ * CH doesn't handle duplicate columns in a block well.
+ * Most of the cases that introduce duplicate columns come from group by.
+ */
+case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan] with Logging {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    visitPlan(plan)
+  }
+
+  private def visitPlan(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case c2r @ CHColumnarToRowExec(hashAgg: CHHashAggregateExecTransformer) =>
+        // This is a special case: the aggregation result is used as the input of the sink.
+        // We need to make the schema the same as the input of the sink.
+        val newChildren = hashAgg.children.map(visitPlan)
+        val newHashAgg = uniqueHashAggregateColumns(hashAgg)
+          .withNewChildren(newChildren)
+          .asInstanceOf[CHHashAggregateExecTransformer]
+        if (newHashAgg.resultExpressions.length != hashAgg.resultExpressions.length) {
+          val project = ProjectExecTransformer(hashAgg.resultExpressions, newHashAgg)
+          c2r.copy(child = project)
+        } else {
+          c2r.copy(child = newHashAgg)
+        }
+      case hashAgg: CHHashAggregateExecTransformer =>
+        val newChildren = hashAgg.children.map(visitPlan)
+        val newHashAgg = uniqueHashAggregateColumns(hashAgg)
+        newHashAgg.withNewChildren(newChildren)
+      case shuffle @ ColumnarShuffleExchangeExec(
+            HashPartitioning(hashExpressions, partitionNum),
+            _,
+            _,
+            _,
+            _) =>
+        val newChildren = shuffle.children.map(visitPlan)
+        val uniqueHashExpressions = uniqueExpressions(hashExpressions)
+        if (uniqueHashExpressions.length != hashExpressions.length) {
+          shuffle
+            .copy(outputPartitioning = HashPartitioning(uniqueHashExpressions, partitionNum))
+            .withNewChildren(newChildren)
+        } else {
+          shuffle.withNewChildren(newChildren)
+        }
+      case _ =>
+        plan.withNewChildren(plan.children.map(visitPlan))
+    }
+  }
+
+  private def unwrapAliasNamedExpression(e: NamedExpression): NamedExpression = {
+    e match {
+      case a: Alias =>
+        if (a.child.isInstanceOf[NamedExpression]) {
+          a.child.asInstanceOf[NamedExpression]
+        } else {
+          a
+        }
+      case _ => e
+    }
+  }
+  private def unwrapAliasExpression(e: Expression): Expression = {
+    e match {
+      case a: Alias =>
+        if (a.child.isInstanceOf[Expression]) {
+          a.child.asInstanceOf[Expression]
+        } else {
+          a
+        }
+      case _ => e
+    }
+  }
+
+  private def uniqueNamedExpressions(
+      groupingExpressions: Seq[NamedExpression]): Seq[NamedExpression] = {
+    var addedExpression = Seq[NamedExpression]()
+    groupingExpressions.foreach {
+      e =>
+        val unwrapped = unwrapAliasNamedExpression(e)
+        if (
+          !addedExpression.exists(_.semanticEquals(unwrapped)) && !unwrapped.isInstanceOf[Literal]
+        ) {
+          addedExpression = addedExpression :+ unwrapped
+        }
+    }
+    addedExpression
+  }
+
+  private def uniqueExpressions(expressions: Seq[Expression]): Seq[Expression] = {
+    var addedExpression = Seq[Expression]()
+    expressions.foreach {
+      e =>
+        val unwrapped = unwrapAliasExpression(e)
+        if (
+          !addedExpression.exists(_.semanticEquals(unwrapped)) && !unwrapped.isInstanceOf[Literal]
+        ) {
+          addedExpression = addedExpression :+ unwrapped
+        }
+    }
+    addedExpression
+  }
+
+  private def uniqueHashAggregateColumns(
+      hashAgg: CHHashAggregateExecTransformer): CHHashAggregateExecTransformer = {
+    val newGroupingExpressions = uniqueNamedExpressions(hashAgg.groupingExpressions)
+    val newResultExpressions = uniqueNamedExpressions(hashAgg.resultExpressions)
+    if (newResultExpressions.length != hashAgg.resultExpressions.length) {
+      hashAgg
+        .copy(
+          groupingExpressions = newGroupingExpressions,
+          resultExpressions = newResultExpressions)
+    } else {
+      hashAgg
+    }
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHNullableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHNullableSuite.scala
index d5e1156ba9d20..b318de4e82ba9 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHNullableSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHNullableSuite.scala
@@ -252,4 +252,5 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit
       }
     })
   }
+
 }
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
index bbe51ef3894c2..65a01dea30730 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSuite.scala
@@ -570,5 +570,25 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
       ", split(concat('a|b|c', cast(id as string)), '|') from range(10)"
     compareResultsAgainstVanillaSpark(sql, true, { _ => })
   }
+  test("GLUTEN-8142 duplicated columns in group by") {
+    sql("create table test_8142 (day string, rtime int, uid string, owner string) using parquet")
+    sql("insert into test_8142 values ('2024-09-01', 123, 'user1', 'owner1')")
+    sql("insert into test_8142 values ('2024-09-01', 123, 'user1', 'owner1')")
+    sql("insert into test_8142 values ('2024-09-02', 567, 'user2', 'owner2')")
+    compareResultsAgainstVanillaSpark(
+      """
+        |select days, rtime, uid, owner, day1
+        |from (
+        |  select day1 as days, rtime, uid, owner, day1
+        |  from (
+        |   select distinct coalesce(day, "today") as day1, rtime, uid, owner
+        |   from test_8142 where day = '2024-09-01'
+        |  )) group by days, rtime, uid, owner, day1
+        |""".stripMargin,
+      true,
+      { _ => }
+    )
+    sql("drop table test_8142")
+  }
 }
 // scalastyle:off line.size.limit
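Note (not part of the patch): the new rule deduplicates grouping and result columns by unwrapping aliases and comparing expressions with semanticEquals. The snippet below is a minimal, self-contained sketch of that idea, assuming only Spark Catalyst on the classpath; the object name DedupSketch and the column names are illustrative and not taken from the Gluten code base.

// Sketch only: mirrors the alias-unwrapping + semanticEquals dedup applied by
// RemoveDuplicatedColumns (the real rule also skips literal grouping keys and
// rewires the aggregate/shuffle nodes in the physical plan).
import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, NamedExpression}
import org.apache.spark.sql.types.StringType

object DedupSketch {
  // Strip a top-level Alias when it merely renames another named expression.
  private def unwrapAlias(e: NamedExpression): NamedExpression = e match {
    case a: Alias if a.child.isInstanceOf[NamedExpression] =>
      a.child.asInstanceOf[NamedExpression]
    case other => other
  }

  // Keep only the first occurrence of each semantically equal expression.
  def dedup(exprs: Seq[NamedExpression]): Seq[NamedExpression] =
    exprs.map(unwrapAlias).foldLeft(Seq.empty[NamedExpression]) {
      (acc, e) => if (acc.exists(_.semanticEquals(e))) acc else acc :+ e
    }

  def main(args: Array[String]): Unit = {
    // In the GLUTEN-8142 query, day1 is grouped on both directly and via the
    // alias `days`, so the two grouping keys collapse into one column.
    val day1 = AttributeReference("day1", StringType)()
    val days = Alias(day1, "days")()
    println(dedup(Seq(days, day1)).map(_.name)) // prints List(day1)
  }
}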