Remove sort before window if window type is 'sort'
WangGuangxin committed Jun 25, 2024
1 parent 1e06169 commit b4bea80
Showing 3 changed files with 79 additions and 24 deletions.
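In short: with the Velox backend's sort-based window implementation, the window operator can sort its input itself, so the explicit sort Spark plans below the window is redundant and is removed; with the streaming window type the sort is still needed and is kept. A minimal sketch of the user-visible effect (illustrative only, not part of the commit; assumes a Gluten-enabled SparkSession named spark and a lineitem table registered as in the test suite below):

// Illustrative only, not part of this commit. Assumes a Gluten-enabled
// SparkSession named `spark` and a registered `lineitem` table.
spark.conf.set("spark.gluten.sql.columnar.backend.velox.window.type", "sort")

val df = spark.sql(
  """select max(l_partkey) over
    |  (partition by l_suppkey order by l_orderkey) from lineitem""".stripMargin)

// With the "sort" window type the executed plan should contain a
// WindowExecTransformer but no SortExecTransformer beneath it;
// with "streaming" the sort below the window remains.
df.explain()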
@@ -274,8 +274,15 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
}

test("window expression") {
def checkWindowAndSortOperator(df: DataFrame, expectedSortNum: Int): Unit = {
val executedPlan = getExecutedPlan(df)
assert(executedPlan.exists(_.isInstanceOf[WindowExecTransformer]))
assert(executedPlan.count(_.isInstanceOf[SortExecTransformer]) == expectedSortNum)
}

Seq("sort", "streaming").foreach {
windowType =>
val expectedSortNum = if (windowType == "sort") 0 else 1
withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) {
runQueryAndCompare(
"select max(l_partkey) over" +
@@ -284,142 +291,154 @@
"min(l_comment) over" +
" (partition by l_suppkey order by l_linenumber" +
" RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum * 2)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey" +
" RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") {
checkSparkOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

// DecimalType as order by column is not supported
runQueryAndCompare(
"select min(l_comment) over" +
" (partition by l_suppkey order by l_discount" +
" RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") {
checkSparkOperatorMatch[WindowExec]
df =>
val executedPlan = getExecutedPlan(df)
assert(executedPlan.count(_.isInstanceOf[WindowExec]) == 1)
// The number of SortExecTransformer should always be 1,
// no matter whether the window type is streaming or sort,
// because WindowExec falls back to Vanilla Spark.
assert(executedPlan.count(_.isInstanceOf[SortExecTransformer]) == 1)
}

runQueryAndCompare(
"select ntile(4) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select row_number() over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select rank() over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select dense_rank() over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => }
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select percent_rank() over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => }
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select cume_dist() over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => }
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) IGNORE NULLS over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select sum(l_partkey + 1) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select max(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select min(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select avg(l_partkey) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select lag(l_orderkey) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

runQueryAndCompare(
"select lead(l_orderkey) over" +
" (partition by l_suppkey order by l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

// Test same partition/ordering keys.
runQueryAndCompare(
"select avg(l_partkey) over" +
" (partition by l_suppkey order by l_suppkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}

// Test overlapping partition/ordering keys.
runQueryAndCompare(
"select avg(l_partkey) over" +
" (partition by l_suppkey order by l_suppkey, l_orderkey) from lineitem ") {
checkGlutenOperatorMatch[WindowExecTransformer]
df => checkWindowAndSortOperator(df, expectedSortNum)
}
}
}
@@ -30,7 +30,10 @@ object MiscColumnarRules {
object TransformPreOverrides {
def apply(): TransformPreOverrides = {
TransformPreOverrides(
List(OffloadFilter()),
List(
OffloadFilter(),
OffloadWindow()
),
List(
OffloadOthers(),
OffloadAggregate(),
@@ -26,6 +26,7 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil}

import org.apache.spark.api.python.EvalPythonExecTransformer
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
import org.apache.spark.sql.execution.datasources.WriteFilesExec
@@ -232,6 +233,38 @@ case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil {
}
}

case class OffloadWindow() extends OffloadSingleNode with LogLevelUtil {
import OffloadOthers._
private val replace = new ReplaceSingleNode()

override def offload(plan: SparkPlan): SparkPlan = plan match {
case window: WindowExec =>
if (TransformHints.isNotTransformable(window)) {
return window
}

val transformer = replace.doReplace(window)
val newChild = transformer.children.head match {
case SortExec(_, false, child, _)
if outputOrderSatisfied(child, transformer.requiredChildOrdering) =>
child
case p @ ProjectExec(_, SortExec(_, false, child, _))
if outputOrderSatisfied(child, transformer.requiredChildOrdering) =>
p.copy(child = child)
case children => children
}
transformer.withNewChildren(Seq(newChild))
case other => other
}

private def outputOrderSatisfied(child: SparkPlan, required: Seq[Seq[SortOrder]]): Boolean = {
Seq(child).zip(required).forall {
case (child, requiredOrdering) =>
SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)
}
}
}

// Other transformations.
case class OffloadOthers() extends OffloadSingleNode with LogLevelUtil {
import OffloadOthers._
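The new OffloadWindow rule only removes a local (global = false) SortExec sitting directly under the replaced window, possibly through an intervening ProjectExec, and only when the sort's own child already provides the ordering the window transformer requires. A minimal, illustrative sketch of that ordering check using Spark's catalyst API (the attribute names and orderings below are assumptions for illustration, not taken from this commit):

import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.types.LongType

// Illustrative attributes standing in for the sort child's output.
val suppkey = AttributeReference("l_suppkey", LongType)()
val orderkey = AttributeReference("l_orderkey", LongType)()

// Ordering already produced by the sort's child (e.g. an upstream operator).
val childOrdering = Seq(SortOrder(suppkey, Ascending), SortOrder(orderkey, Ascending))
// Ordering the window transformer requires on its single child.
val requiredOrdering = Seq(SortOrder(suppkey, Ascending))

// True: the required ordering is satisfied by a prefix of the child's
// ordering, so the intermediate SortExec adds nothing and can be dropped.
println(SortOrder.orderingSatisfies(childOrdering, requiredOrdering))

In the rule itself this check runs against transformer.requiredChildOrdering via outputOrderSatisfied, and sorts with global = true are not matched.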
