fixed: duplicated cols in group by
lgbo-ustc committed Dec 6, 2024
1 parent e34914d commit a6b223e
Showing 5 changed files with 166 additions and 2 deletions.
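The failure pattern, in brief: when an alias and the column it aliases both end up in the group-by keys, the ClickHouse backend receives the same column twice in one block. A hypothetical minimal shape of such a query, assuming a Gluten/CH-enabled session named spark (the committed GLUTEN-8142 test at the bottom of this page is the authoritative reproducer):

// day1 is exposed both under its own name and as the alias days; both names
// then appear in the group-by keys and resolve to the same underlying column.
spark
  .sql("""
    select days, day1 from (
      select day1 as days, day1
      from (select distinct coalesce(id, 0) as day1 from range(10))
    ) group by days, day1
  """)
  .show()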
@@ -118,6 +118,7 @@ object CHRuleApi {
        SparkPlanRules.extendedColumnarRule(c.glutenConf.extendedColumnarTransformRules)(
          c.session)))
    injector.injectPostTransform(c => InsertTransitions.create(c.outputsColumnar, CHBatch))
    // Remove duplicated columns, which the CH backend cannot handle in a single block.
    injector.injectPostTransform(c => RemoveDuplicatedColumns.apply(c.session))

    // Gluten columnar: Fallback policies.
    injector.injectFallbackPolicy(
@@ -164,11 +164,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging {
      resultExpressions)
    CHHashAggregateExecTransformer(
      requiredChildDistributionExpressions,
-     groupingExpressions.distinct,
+     groupingExpressions,
      aggregateExpressions,
      aggregateAttributes,
      initialInputBufferOffset,
-     replacedResultExpressions.distinct,
+     replacedResultExpressions,
      child
    )
  }
@@ -0,0 +1,142 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.gluten.extension

import org.apache.gluten.execution._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.CHColumnarToRowExec

/*
 * CH doesn't work well with duplicate columns in a block.
 * Most of the cases that introduce duplicate columns come from group by.
 */
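/*
 * Example (see the GLUTEN-8142 test in this commit): in
 *   select day1 as days, ..., day1 ... group by days, ..., day1
 * both `days` and `day1` resolve to the same underlying column once the alias
 * is unwrapped, so the grouping keys would carry the same column twice.
 */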
case class RemoveDuplicatedColumns(session: SparkSession) extends Rule[SparkPlan] with Logging {
  override def apply(plan: SparkPlan): SparkPlan = {
    visitPlan(plan)
  }

  private def visitPlan(plan: SparkPlan): SparkPlan = {
    plan match {
      case c2r @ CHColumnarToRowExec(hashAgg: CHHashAggregateExecTransformer) =>
        // Special case: the aggregation result is used directly as the input of the sink,
        // so the output schema must match the sink's input schema. If deduplication
        // changed the schema, restore the original result schema with a projection.
        val newChildren = hashAgg.children.map(visitPlan)
        val newHashAgg = uniqueHashAggregateColumns(hashAgg)
          .withNewChildren(newChildren)
          .asInstanceOf[CHHashAggregateExecTransformer]
        if (newHashAgg.resultExpressions.length != hashAgg.resultExpressions.length) {
          val project = ProjectExecTransformer(hashAgg.resultExpressions, newHashAgg)
          c2r.copy(child = project)
        } else {
          c2r.copy(child = newHashAgg)
        }
      case hashAgg: CHHashAggregateExecTransformer =>
        val newChildren = hashAgg.children.map(visitPlan)
        val newHashAgg = uniqueHashAggregateColumns(hashAgg)
        newHashAgg.withNewChildren(newChildren)
      case shuffle @ ColumnarShuffleExchangeExec(
            HashPartitioning(hashExpressions, partitionNum),
            _,
            _,
            _,
            _) =>
        // Duplicates can also appear among hash partitioning keys; dedup them as well.
        val newChildren = shuffle.children.map(visitPlan)
        val uniqueHashExpressions = uniqueExpressions(hashExpressions)
        if (uniqueHashExpressions.length != hashExpressions.length) {
          shuffle
            .copy(outputPartitioning = HashPartitioning(uniqueHashExpressions, partitionNum))
            .withNewChildren(newChildren)
        } else {
          shuffle.withNewChildren(newChildren)
        }
      case _ =>
        plan.withNewChildren(plan.children.map(visitPlan))
    }
  }

  // Strip the top-level alias when its child is itself a named expression;
  // otherwise keep the alias so the expression stays named.
  private def unwrapAliasNamedExpression(e: NamedExpression): NamedExpression = {
    e match {
      case a: Alias if a.child.isInstanceOf[NamedExpression] =>
        a.child.asInstanceOf[NamedExpression]
      case _ => e
    }
  }

  // An Alias's child is always an Expression, so unwrapping is unconditional here.
  private def unwrapAliasExpression(e: Expression): Expression = {
    e match {
      case a: Alias => a.child
      case _ => e
    }
  }

  // Keep the first occurrence of each semantically distinct, non-literal expression,
  // comparing expressions with their top-level aliases stripped.
  private def uniqueNamedExpressions(
      groupingExpressions: Seq[NamedExpression]): Seq[NamedExpression] = {
    var addedExpression = Seq[NamedExpression]()
    groupingExpressions.foreach {
      e =>
        val unwrapped = unwrapAliasNamedExpression(e)
        if (
          !addedExpression.exists(_.semanticEquals(unwrapped)) && !unwrapped.isInstanceOf[Literal]
        ) {
          addedExpression = addedExpression :+ unwrapped
        }
    }
    addedExpression
  }

  // Same as uniqueNamedExpressions, but over arbitrary expressions.
  private def uniqueExpressions(expressions: Seq[Expression]): Seq[Expression] = {
    var addedExpression = Seq[Expression]()
    expressions.foreach {
      e =>
        val unwrapped = unwrapAliasExpression(e)
        if (
          !addedExpression.exists(_.semanticEquals(unwrapped)) && !unwrapped.isInstanceOf[Literal]
        ) {
          addedExpression = addedExpression :+ unwrapped
        }
    }
    addedExpression
  }

  // Deduplicate grouping and result expressions; rebuild the aggregate only when
  // the result schema actually shrank.
  private def uniqueHashAggregateColumns(
      hashAgg: CHHashAggregateExecTransformer): CHHashAggregateExecTransformer = {
    val newGroupingExpressions = uniqueNamedExpressions(hashAgg.groupingExpressions)
    val newResultExpressions = uniqueNamedExpressions(hashAgg.resultExpressions)
    if (newResultExpressions.length != hashAgg.resultExpressions.length) {
      hashAgg
        .copy(
          groupingExpressions = newGroupingExpressions,
          resultExpressions = newResultExpressions)
    } else {
      hashAgg
    }
  }
}
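The dedup core above is ordinary catalyst-expression bookkeeping, so it can be exercised standalone. A minimal sketch (not part of the commit; the object name and column names are illustrative) mirroring uniqueNamedExpressions: unwrap aliases, then keep only semantically distinct, non-literal expressions. It runs against plain catalyst expressions, no Spark session needed:

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.StringType

object DedupSketch {
  def main(args: Array[String]): Unit = {
    val day = AttributeReference("day", StringType)()
    val days = Alias(day, "days")() // `day1 as days` from the test, in miniature
    val keys: Seq[NamedExpression] = Seq(days, day)

    // Mirror uniqueNamedExpressions: unwrap the top-level alias, then keep only
    // the first semantically distinct, non-literal expression.
    var kept = Seq[NamedExpression]()
    keys.foreach { e =>
      val unwrapped = e match {
        case Alias(child: NamedExpression, _) => child
        case other => other
      }
      if (!kept.exists(_.semanticEquals(unwrapped)) && !unwrapped.isInstanceOf[Literal]) {
        kept = kept :+ unwrapped
      }
    }
    println(kept.map(_.name)) // List(day) -- the aliased duplicate collapses away
  }
}

The collapse is also why the rule re-projects at the CHColumnarToRowExec boundary: dropping the aliased duplicate shrinks the schema, which a projection must restore wherever the original shape is required.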
@@ -252,4 +252,5 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuite {
      }
    })
  }

}
@@ -570,5 +570,25 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
", split(concat('a|b|c', cast(id as string)), '|') from range(10)"
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
test("GLUTEN-8142 duplicated columns in group by") {
sql("create table test_8142 (day string, rtime int, uid string, owner string) using parquet")
sql("insert into test_8142 values ('2024-09-01', 123, 'user1', 'owner1')")
sql("insert into test_8142 values ('2024-09-01', 123, 'user1', 'owner1')")
sql("insert into test_8142 values ('2024-09-02', 567, 'user2', 'owner2')")
compareResultsAgainstVanillaSpark(
"""
|select days, rtime, uid, owner, day1
|from (
| select day1 as days, rtime, uid, owner, day1
| from (
| select distinct coalesce(day, "today") as day1, rtime, uid, owner
| from test_8142 where day = '2024-09-01'
| )) group by days, rtime, uid, owner, day1
|""".stripMargin,
true,
{ _ => }
)
sql("drop table test_8142")
}
}
// scalastyle:off line.size.limit
