diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala index a3b74366fc7b..52e278b3dace 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/CommonSubexpressionEliminateRule.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.internal.SQLConf import scala.collection.mutable +// If you want to debug CommonSubexpressionEliminateRule, you can: +// 1. replace all `logTrace` to `logError` +// 2. append two options to spark config +// --conf spark.sql.planChangeLog.level=error +// --conf spark.sql.planChangeLog.batches=all class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf) extends Rule[LogicalPlan] with Logging { @@ -121,7 +126,12 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf) if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { addToEquivalentExpressions(expr, equivalentExpressions) } else { - equivalentExpressions.addExprTree(expr) + expr match { + case alias: Alias => + equivalentExpressions.addExprTree(alias.child) + case _ => + equivalentExpressions.addExprTree(expr) + } } }) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala index f165d7aef69c..cbc3aed3607d 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala @@ -985,19 +985,19 @@ class GlutenClickHouseHiveTableSuite } } - test("GLUTEN-4333: fix CSE in aggregate operator") { - def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit - tag: ClassTag[T]): Unit = { - if (sparkVersion.equals("3.3")) { - assert( - getExecutedPlan(df).count( - plan => { - plan.getClass == tag.runtimeClass - }) == count, - s"executed plan: ${getExecutedPlan(df)}") - } + def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit + tag: ClassTag[T]): Unit = { + if (sparkVersion.equals("3.3")) { + assert( + getExecutedPlan(df).count( + plan => { + plan.getClass == tag.runtimeClass + }) == count, + s"executed plan: ${getExecutedPlan(df)}") } + } + test("GLUTEN-4333: fix CSE in aggregate operator") { val createTableSql = """ |CREATE TABLE `test_cse`( @@ -1262,4 +1262,66 @@ class GlutenClickHouseHiveTableSuite compareResultsAgainstVanillaSpark(selectSql, true, _ => {}) sql(s"drop table if exists $tableName") } + + test("GLUTEN-7054: Fix exception when CSE meets common alias expression") { + val createTableSql = """ + |CREATE TABLE test_tbl_7054 ( + | day STRING, + | event_id STRING, + | event STRUCT< + | event_info: MAP + | > + |) STORED AS PARQUET; + |""".stripMargin + + val insertDataSql = """ + |INSERT INTO test_tbl_7054 + |VALUES + | ('2024-08-27', '011441004', + | STRUCT(MAP('type', '1', 'action', '8', 'value_vmoney', '100'))), + | ('2024-08-27', '011441004', + | STRUCT(MAP('type', '2', 'action', '8', 'value_vmoney', '200'))), + | ('2024-08-27', '011441004', + | STRUCT(MAP('type', '4', 'action', '8', 'value_vmoney', '300'))); + |""".stripMargin + + val selectSql = """ + |SELECT + | COALESCE(day, 'all') AS daytime, + | COALESCE(type, 'all') AS type, + | COALESCE(value_money, 'all') AS value_vmoney, + | SUM(CASE + | WHEN type IN (1, 2) AND action = 8 THEN value_vmoney + | ELSE 0 + | END) / 60 AS total_value_vmoney + |FROM ( + | SELECT + | day, + | type, + | NVL(CAST(value_vmoney AS BIGINT), 0) AS value_money, + | action, + | type, + | CAST(value_vmoney AS BIGINT) AS value_vmoney + | FROM ( + | SELECT + | day, + | event.event_info["type"] AS type, + | event.event_info["action"] AS action, + | event.event_info["value_vmoney"] AS value_vmoney + | FROM test_tbl_7054 + | WHERE + | day = '2024-08-27' + | AND event_id = '011441004' + | AND event.event_info["type"] IN (1, 2, 4) + | ) a + |) b + |GROUP BY + | day, type, value_money + |""".stripMargin + + spark.sql(createTableSql) + spark.sql(insertDataSql) + runQueryAndCompare(selectSql)(df => checkOperatorCount[ProjectExecTransformer](3)(df)) + spark.sql("DROP TABLE test_tbl_7054") + } }