Fix
viirya committed Jul 24, 2024
1 parent 6a9b55b commit d46d4e9
Showing 1 changed file with 92 additions and 14 deletions.
dev/diffs/3.4.3.diff: 106 changes (92 additions & 14 deletions)
@@ -1144,65 +1144,143 @@ index 47679ed7865..9ffbaecb98e 100644
assert(collectWithSubqueries(plan) { case s: SortAggregateExec => s }.length == sortAggCount)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index ac710c32296..e163c1a6a76 100644
index ac710c32296..baae214c6ee 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution

import org.apache.spark.sql.{Dataset, QueryTest, Row, SaveMode}
import org.apache.spark.sql.catalyst.expressions.codegen.{ByteCodeStats, CodeAndComment, CodeGenerator}
+import org.apache.spark.sql.comet.CometSortMergeJoinExec
+import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.adaptive.DisableAdaptiveExecutionSuite
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
@@ -224,6 +225,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -169,6 +170,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_HASH"), $"k1" === $"k2")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4)))

@@ -177,6 +179,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df3.hint("SHUFFLE_HASH"), $"k1" === $"k3")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) => true
+ case _: CometHashJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, 4, 4)))
@@ -193,6 +196,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinUniqueDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, 4),
Row(null, 5), Row(null, 6), Row(null, 7), Row(null, 8), Row(null, 9)))
@@ -203,6 +208,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinNonUniqueDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinNonUniqueDF, Seq(Row(0, 0), Row(0, 3), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 4), Row(1, 7), Row(2, 2), Row(2, 5), Row(2, 8), Row(3, null), Row(4, null)))
@@ -213,6 +220,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(joinWithNonEquiDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 1)
checkAnswer(joinWithNonEquiDF, Seq(Row(0, 0), Row(0, 6), Row(0, 9), Row(1, 1),
Row(1, 7), Row(2, 2), Row(2, 8), Row(3, null), Row(4, null), Row(null, 3), Row(null, 4),
@@ -224,6 +233,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : ShuffledHashJoinExec) if hint == "SHUFFLE_HASH" => true
case WholeStageCodegenExec(_ : SortMergeJoinExec) if hint == "SHUFFLE_MERGE" => true
+ case _: CometHashJoinExec if hint == "SHUFFLE_HASH" => true
+ case _: CometSortMergeJoinExec if hint == "SHUFFLE_MERGE" => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, null), Row(2, 2, 2), Row(3, 3, null), Row(4, 4, null),
@@ -258,6 +260,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -241,6 +252,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneLeftOuterJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_outer")
assert(oneLeftOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneLeftOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(4, null),
Row(5, null), Row(6, null), Row(7, null), Row(8, null), Row(9, null)))
@@ -249,6 +261,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneRightOuterJoinDF = df2.join(df3.hint("SHUFFLE_MERGE"), $"k2" === $"k3", "right_outer")
assert(oneRightOuterJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneRightOuterJoinDF, Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(3, 3), Row(null, 4),
Row(null, 5)))
@@ -258,6 +271,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "right_outer")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF,
Seq(Row(0, 0, 0), Row(1, 1, 1), Row(2, 2, 2), Row(3, 3, 3), Row(4, null, 4), Row(5, null, 5),
@@ -280,8 +283,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -273,6 +287,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_semi")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(0), Row(1), Row(2), Row(3)))

@@ -280,8 +295,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df3.join(df2.hint("SHUFFLE_MERGE"), $"k3" === $"k2", "left_semi")
.join(df1.hint("SHUFFLE_MERGE"), $"k3" === $"k1", "left_semi")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(0), Row(1), Row(2), Row(3)))
}
@@ -302,8 +304,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -295,6 +310,7 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val oneJoinDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
assert(oneJoinDF.queryExecution.executedPlan.collect {
case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) => true
+ case _: CometSortMergeJoinExec => true
}.size === 1)
checkAnswer(oneJoinDF, Seq(Row(4), Row(5), Row(6), Row(7), Row(8), Row(9)))

@@ -302,8 +318,8 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val twoJoinsDF = df1.join(df2.hint("SHUFFLE_MERGE"), $"k1" === $"k2", "left_anti")
.join(df3.hint("SHUFFLE_MERGE"), $"k1" === $"k3", "left_anti")
assert(twoJoinsDF.queryExecution.executedPlan.collect {
- case WholeStageCodegenExec(ProjectExec(_, _ : SortMergeJoinExec)) |
- WholeStageCodegenExec(_ : SortMergeJoinExec) => true
+ case _: SortMergeJoinExec => true
+ case _: CometSortMergeJoinExec => true
}.size === 2)
checkAnswer(twoJoinsDF, Seq(Row(6), Row(7), Row(8), Row(9)))
}
@@ -436,7 +437,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
val plan = df.queryExecution.executedPlan
assert(plan.exists(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
@@ -433,10 +449,6 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession

test("Sort should be included in WholeStageCodegen") {
val df = spark.range(3, 0, -1).toDF().sort(col("id"))
- val plan = df.queryExecution.executedPlan
- assert(plan.exists(p =>
- p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SortExec]))
+ p.asInstanceOf[WholeStageCodegenExec].collect {
+ case _: SortExec => true
+ }.nonEmpty))
assert(df.collect() === Array(Row(1), Row(2), Row(3)))
}

@@ -616,7 +619,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
@@ -616,7 +628,9 @@ class WholeStageCodegenSuite extends QueryTest with SharedSparkSession
.write.mode(SaveMode.Overwrite).parquet(path)

withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "255",
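
The changes above all follow one pattern: WholeStageCodegenSuite asserts that a join appears in the executed plan as a Spark operator wrapped in WholeStageCodegenExec, but with Comet enabled those joins are replaced by CometHashJoinExec or CometSortMergeJoinExec, so each collect match gains a case for the Comet operator. A minimal Scala sketch of that matching pattern, assuming the Comet classes are on the test classpath; the object and method names below (JoinPlanChecks, countHashJoins, countSortMergeJoins) are illustrative and not part of the patch, while the exec classes are the ones referenced in the diff:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.comet.{CometHashJoinExec, CometSortMergeJoinExec}
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}

object JoinPlanChecks {
  // Count shuffled hash joins whether they run through Spark codegen or Comet.
  def countHashJoins(df: DataFrame): Int =
    df.queryExecution.executedPlan.collect {
      case WholeStageCodegenExec(_: ShuffledHashJoinExec) => true // vanilla Spark
      case _: CometHashJoinExec => true                           // Comet replacement
    }.size

  // Same idea for sort-merge joins.
  def countSortMergeJoins(df: DataFrame): Int =
    df.queryExecution.executedPlan.collect {
      case WholeStageCodegenExec(_: SortMergeJoinExec) => true
      case _: CometSortMergeJoinExec => true
    }.size
}

A test written this way passes with and without Comet, which is why the patched assertions count matching operators rather than pinning the exact plan shape.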