Skip to content

Commit

Permalink
[GLUTEN-7054][CH] Fix cse alias issues (apache#7084)
Browse files Browse the repository at this point in the history
* fix cse alias issues

* fix issue apache#7054

* fix uts
  • Loading branch information
taiyang-li authored Sep 3, 2024
1 parent 22bb0b3 commit 6e0b119
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import org.apache.spark.sql.internal.SQLConf

import scala.collection.mutable

// If you want to debug CommonSubexpressionEliminateRule, you can:
// 1. replace all `logTrace` to `logError`
// 2. append two options to spark config
// --conf spark.sql.planChangeLog.level=error
// --conf spark.sql.planChangeLog.batches=all
class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
extends Rule[LogicalPlan]
with Logging {
Expand Down Expand Up @@ -121,7 +126,12 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
addToEquivalentExpressions(expr, equivalentExpressions)
} else {
equivalentExpressions.addExprTree(expr)
expr match {
case alias: Alias =>
equivalentExpressions.addExprTree(alias.child)
case _ =>
equivalentExpressions.addExprTree(expr)
}
}
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,19 +985,19 @@ class GlutenClickHouseHiveTableSuite
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
val createTableSql =
"""
|CREATE TABLE `test_cse`(
Expand Down Expand Up @@ -1262,4 +1262,66 @@ class GlutenClickHouseHiveTableSuite
compareResultsAgainstVanillaSpark(selectSql, true, _ => {})
sql(s"drop table if exists $tableName")
}

test("GLUTEN-7054: Fix exception when CSE meets common alias expression") {
val createTableSql = """
|CREATE TABLE test_tbl_7054 (
| day STRING,
| event_id STRING,
| event STRUCT<
| event_info: MAP<STRING, STRING>
| >
|) STORED AS PARQUET;
|""".stripMargin

val insertDataSql = """
|INSERT INTO test_tbl_7054
|VALUES
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '1', 'action', '8', 'value_vmoney', '100'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '2', 'action', '8', 'value_vmoney', '200'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '4', 'action', '8', 'value_vmoney', '300')));
|""".stripMargin

val selectSql = """
|SELECT
| COALESCE(day, 'all') AS daytime,
| COALESCE(type, 'all') AS type,
| COALESCE(value_money, 'all') AS value_vmoney,
| SUM(CASE
| WHEN type IN (1, 2) AND action = 8 THEN value_vmoney
| ELSE 0
| END) / 60 AS total_value_vmoney
|FROM (
| SELECT
| day,
| type,
| NVL(CAST(value_vmoney AS BIGINT), 0) AS value_money,
| action,
| type,
| CAST(value_vmoney AS BIGINT) AS value_vmoney
| FROM (
| SELECT
| day,
| event.event_info["type"] AS type,
| event.event_info["action"] AS action,
| event.event_info["value_vmoney"] AS value_vmoney
| FROM test_tbl_7054
| WHERE
| day = '2024-08-27'
| AND event_id = '011441004'
| AND event.event_info["type"] IN (1, 2, 4)
| ) a
|) b
|GROUP BY
| day, type, value_money
|""".stripMargin

spark.sql(createTableSql)
spark.sql(insertDataSql)
runQueryAndCompare(selectSql)(df => checkOperatorCount[ProjectExecTransformer](3)(df))
spark.sql("DROP TABLE test_tbl_7054")
}
}

0 comments on commit 6e0b119

Please sign in to comment.