Skip to content

Commit

Permalink
Merge branch 'main' into fix-input-metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
ivoson authored Sep 3, 2024
2 parents 747a103 + 6e0b119 commit 40b444f
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ import org.apache.spark.sql.internal.SQLConf

import scala.collection.mutable

// If you want to debug CommonSubexpressionEliminateRule, you can:
// 1. replace all `logTrace` to `logError`
// 2. append two options to spark config
// --conf spark.sql.planChangeLog.level=error
// --conf spark.sql.planChangeLog.batches=all
class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
extends Rule[LogicalPlan]
with Logging {
Expand Down Expand Up @@ -121,7 +126,12 @@ class CommonSubexpressionEliminateRule(session: SparkSession, conf: SQLConf)
if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) {
addToEquivalentExpressions(expr, equivalentExpressions)
} else {
equivalentExpressions.addExprTree(expr)
expr match {
case alias: Alias =>
equivalentExpressions.addExprTree(alias.child)
case _ =>
equivalentExpressions.addExprTree(expr)
}
}
})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -985,19 +985,19 @@ class GlutenClickHouseHiveTableSuite
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
def checkOperatorCount[T <: TransformSupport](count: Int)(df: DataFrame)(implicit
tag: ClassTag[T]): Unit = {
if (sparkVersion.equals("3.3")) {
assert(
getExecutedPlan(df).count(
plan => {
plan.getClass == tag.runtimeClass
}) == count,
s"executed plan: ${getExecutedPlan(df)}")
}
}

test("GLUTEN-4333: fix CSE in aggregate operator") {
val createTableSql =
"""
|CREATE TABLE `test_cse`(
Expand Down Expand Up @@ -1262,4 +1262,66 @@ class GlutenClickHouseHiveTableSuite
compareResultsAgainstVanillaSpark(selectSql, true, _ => {})
sql(s"drop table if exists $tableName")
}

test("GLUTEN-7054: Fix exception when CSE meets common alias expression") {
val createTableSql = """
|CREATE TABLE test_tbl_7054 (
| day STRING,
| event_id STRING,
| event STRUCT<
| event_info: MAP<STRING, STRING>
| >
|) STORED AS PARQUET;
|""".stripMargin

val insertDataSql = """
|INSERT INTO test_tbl_7054
|VALUES
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '1', 'action', '8', 'value_vmoney', '100'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '2', 'action', '8', 'value_vmoney', '200'))),
| ('2024-08-27', '011441004',
| STRUCT(MAP('type', '4', 'action', '8', 'value_vmoney', '300')));
|""".stripMargin

val selectSql = """
|SELECT
| COALESCE(day, 'all') AS daytime,
| COALESCE(type, 'all') AS type,
| COALESCE(value_money, 'all') AS value_vmoney,
| SUM(CASE
| WHEN type IN (1, 2) AND action = 8 THEN value_vmoney
| ELSE 0
| END) / 60 AS total_value_vmoney
|FROM (
| SELECT
| day,
| type,
| NVL(CAST(value_vmoney AS BIGINT), 0) AS value_money,
| action,
| type,
| CAST(value_vmoney AS BIGINT) AS value_vmoney
| FROM (
| SELECT
| day,
| event.event_info["type"] AS type,
| event.event_info["action"] AS action,
| event.event_info["value_vmoney"] AS value_vmoney
| FROM test_tbl_7054
| WHERE
| day = '2024-08-27'
| AND event_id = '011441004'
| AND event.event_info["type"] IN (1, 2, 4)
| ) a
|) b
|GROUP BY
| day, type, value_money
|""".stripMargin

spark.sql(createTableSql)
spark.sql(insertDataSql)
runQueryAndCompare(selectSql)(df => checkOperatorCount[ProjectExecTransformer](3)(df))
spark.sql("DROP TABLE test_tbl_7054")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,20 @@ class VeloxWindowExpressionSuite extends WholeStageTransformerSuite {
}
}

test("test overlapping partition and sorting keys") {
runAndCompare(
"""
|WITH t AS (
|SELECT
| l_linenumber,
| row_number() over (partition by l_linenumber order by l_linenumber) as rn
|FROM lineitem
|)
|SELECT * FROM t WHERE rn = 1
|""".stripMargin
) {}
}

test("collect_list / collect_set") {
withTable("t") {
val data = Seq(
Expand Down
8 changes: 8 additions & 0 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1024,6 +1024,14 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(
}
}
const std::optional<std::string> rowNumberColumnName = std::nullopt;

if (sortingKeys.empty()) {
// Handle if all sorting keys are also used as partition keys.

return std::make_shared<core::RowNumberNode>(
nextPlanNodeId(), partitionKeys, rowNumberColumnName, (int32_t)windowGroupLimitRel.limit(), childNode);
}

return std::make_shared<core::TopNRowNumberNode>(
nextPlanNodeId(),
partitionKeys,
Expand Down

0 comments on commit 40b444f

Please sign in to comment.