[GLUTEN-7054][CH] Fix cse alias issues #7084

Merged · 3 commits · Sep 3, 2024
@@ -28,6 +28,11 @@ import org.apache.spark.sql.internal.SQLConf

import scala.collection.mutable

// If you want to debug CommonSubexpressionEliminateRule, you can:
// 1. replace all `logTrace` with `logError`
// 2. append two options to the spark config
// --conf spark.sql.planChangeLog.level=error
// --conf spark.sql.planChangeLog.batches=all
class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
extends Rule[LogicalPlan]
with Logging {
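
The two options named in the comment above are standard Spark plan-change-log settings. A minimal Scala sketch of setting them programmatically (illustrative only, not part of this PR; the local master and app name are assumptions):

import org.apache.spark.sql.SparkSession

// Illustrative local session; in practice the two options are usually passed
// on the command line via --conf, as the comment above describes.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("cse-rule-debug")
  // Log plan changes made by optimizer rules at ERROR level so they are
  // visible without lowering the global log level.
  .config("spark.sql.planChangeLog.level", "error")
  // Value taken from the comment above: log changes for all rule batches.
  .config("spark.sql.planChangeLog.batches", "all")
  .getOrCreate()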
@@ -121,7 +126,12 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
addToEquivalentExpressions(expr, equivalentExpressions)
} else {
equivalentExpressions.addExprTree(expr)
expr match {
case alias: Alias =>
equivalentExpressions.addExprTree(alias.child)
case _ =>
equivalentExpressions.addExprTree(expr)
}
}
})
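
For context on the change above: registering the child of an Alias, rather than the Alias node itself, lets an aliased expression and an un-aliased copy of the same expression be matched as one common subexpression. A standalone Scala sketch of that behavior against Catalyst's EquivalentExpressions (illustrative only; the object name and the attribute are assumptions, not part of this PR):

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{LongType, StringType}

object CseAliasSketch {
  def main(args: Array[String]): Unit = {
    // One projection computes CAST(value_vmoney AS BIGINT) directly, another
    // computes the same cast wrapped in an Alias, as in the GLUTEN-7054 query.
    val col = AttributeReference("value_vmoney", StringType)()
    val cast = Cast(col, LongType)
    val aliased = Alias(cast, "value_vmoney")()

    val equiv = new EquivalentExpressions
    // Mirror the `case alias: Alias` branch above: register the alias's child
    // so the wrapper does not hide the shared subtree.
    equiv.addExprTree(aliased.child)
    equiv.addExprTree(cast)

    // The shared Cast should now be reported as a common subexpression.
    equiv.getCommonSubexpressions.foreach(e => println(e.sql))
  }
}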

@@ -985,19 +985,19 @@ class GlutenClickHouseHiveTableSuite
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
val createTableSql =
"""
|CREATE TABLE `test_cse`(
@@ -1262,4 +1262,66 @@ class GlutenClickHouseHiveTableSuite
compareResultsAgainstVanillaSpark(selectSql, true, _ => {})
sql(s"drop table if exists $tableName")
}

test("GLUTEN-7054: Fix exception when CSE meets common alias expression") {
val createTableSql = """
|CREATE TABLE test_tbl_7054 (
| day STRING,
| event_id STRING,
| event STRUCT<
| event_info: MAP<STRING, STRING>
| >
|) STORED AS PARQUET;
|""".stripMargin

val insertDataSql = """
|INSERT INTO test_tbl_7054
|VALUES
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '1', 'action', '8', 'value_vmoney', '100'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '2', 'action', '8', 'value_vmoney', '200'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '4', 'action', '8', 'value_vmoney', '300')));
|""".stripMargin

val selectSql = """
|SELECT
| COALESCE(day, 'all') AS daytime,
| COALESCE(type, 'all') AS type,
| COALESCE(value_money, 'all') AS value_vmoney,
| SUM(CASE
| WHEN type IN (1, 2) AND action = 8 THEN value_vmoney
| ELSE 0
| END) / 60 AS total_value_vmoney
|FROM (
| SELECT
| day,
| type,
| NVL(CAST(value_vmoney AS BIGINT), 0) AS value_money,
| action,
| type,
| CAST(value_vmoney AS BIGINT) AS value_vmoney
| FROM (
| SELECT
| day,
| event.event_info["type"] AS type,
| event.event_info["action"] AS action,
| event.event_info["value_vmoney"] AS value_vmoney
| FROM test_tbl_7054
| WHERE
| day = '2024-08-27'
| AND event_id = '011441004'
| AND event.event_info["type"] IN (1, 2, 4)
| ) a
|) b
|GROUP BY
| day, type, value_money
|""".stripMargin

spark.sql(createTableSql)
spark.sql(insertDataSql)
runQueryAndCompare(selectSql)(df => checkOperatorCount[ProjectExecTransformer](3)(df))
spark.sql("DROP TABLE test_tbl_7054")
}
}