diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index a892b6f313a4e..53eafc4bdb912 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -274,8 +274,15 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla } test("window expression") { + def checkWindowAndSortOperator(df: DataFrame, expectedSortNum: Int): Unit = { + val executedPlan = getExecutedPlan(df) + assert(executedPlan.exists(_.isInstanceOf[WindowExecTransformer])) + assert(executedPlan.count(_.isInstanceOf[SortExecTransformer]) == expectedSortNum) + } + Seq("sort", "streaming").foreach { windowType => + val expectedSortNum = if (windowType == "sort") 0 else 1 withSQLConf("spark.gluten.sql.columnar.backend.velox.window.type" -> windowType) { runQueryAndCompare( "select max(l_partkey) over" + @@ -284,42 +291,42 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla "min(l_comment) over" + " (partition by l_suppkey order by l_linenumber" + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum * 2) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN CURRENT ROW AND 2 FOLLOWING) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN 6 PRECEDING AND CURRENT ROW) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN 6 PRECEDING AND 3 PRECEDING) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey" + " RANGE BETWEEN 3 FOLLOWING AND 6 FOLLOWING) from lineitem ") { - checkSparkOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } // DecimalType as order by column is not supported @@ -327,99 +334,111 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla "select min(l_comment) over" + " (partition by l_suppkey order by l_discount" + " RANGE BETWEEN 1 PRECEDING AND CURRENT ROW) from lineitem ") { - checkSparkOperatorMatch[WindowExec] + df => + val executedPlan = getExecutedPlan(df) + assert(executedPlan.count(_.isInstanceOf[WindowExec]) == 1) + // The number of SortExecTransformer should always be 1 + // no matter the window type is streaming or sort, + // because WindowExec is fallback to Vanilla Spark + assert(executedPlan.count(_.isInstanceOf[SortExecTransformer]) == 1) } runQueryAndCompare( "select ntile(4) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select row_number() over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select rank() over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select dense_rank() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + " (partition by l_suppkey order by l_orderkey) from lineitem ") { + df => checkWindowAndSortOperator(df, expectedSortNum) + } runQueryAndCompare( "select percent_rank() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + " (partition by l_suppkey order by l_orderkey) from lineitem ") { + df => checkWindowAndSortOperator(df, expectedSortNum) + } runQueryAndCompare( "select cume_dist() over" + - " (partition by l_suppkey order by l_orderkey) from lineitem ") { _ => } + " (partition by l_suppkey order by l_orderkey) from lineitem ") { + df => checkWindowAndSortOperator(df, expectedSortNum) + } runQueryAndCompare( "select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select l_suppkey, l_orderkey, nth_value(l_orderkey, 2) IGNORE NULLS over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select sum(l_partkey + 1) over" + " (partition by l_suppkey order by l_orderkey) from lineitem") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select max(l_partkey) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select min(l_partkey) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select avg(l_partkey) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select lag(l_orderkey) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } runQueryAndCompare( "select lead(l_orderkey) over" + " (partition by l_suppkey order by l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } // Test same partition/ordering keys. runQueryAndCompare( "select avg(l_partkey) over" + " (partition by l_suppkey order by l_suppkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } // Test overlapping partition/ordering keys. runQueryAndCompare( "select avg(l_partkey) over" + " (partition by l_suppkey order by l_suppkey, l_orderkey) from lineitem ") { - checkGlutenOperatorMatch[WindowExecTransformer] + df => checkWindowAndSortOperator(df, expectedSortNum) } } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala index 8ed2137f4489f..d990bb5ea5d7c 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/MiscColumnarRules.scala @@ -30,7 +30,10 @@ object MiscColumnarRules { object TransformPreOverrides { def apply(): TransformPreOverrides = { TransformPreOverrides( - List(OffloadFilter()), + List( + OffloadFilter(), + OffloadWindow() + ), List( OffloadOthers(), OffloadAggregate(), diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 8cd2a5fb67bda..d0e2c8f7e0b4b 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -26,6 +26,7 @@ import org.apache.gluten.utils.{LogLevelUtil, PlanUtil} import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.datasources.WriteFilesExec @@ -232,6 +233,38 @@ case class OffloadFilter() extends OffloadSingleNode with LogLevelUtil { } } +case class OffloadWindow() extends OffloadSingleNode with LogLevelUtil { + import OffloadOthers._ + private val replace = new ReplaceSingleNode() + + override def offload(plan: SparkPlan): SparkPlan = plan match { + case window: WindowExec => + if (TransformHints.isNotTransformable(window)) { + return window + } + + val transformer = replace.doReplace(window) + val newChild = transformer.children.head match { + case SortExec(_, false, child, _) + if outputOrderSatisfied(child, transformer.requiredChildOrdering) => + child + case p @ ProjectExec(_, SortExec(_, false, child, _)) + if outputOrderSatisfied(child, transformer.requiredChildOrdering) => + p.copy(child = child) + case children => children + } + transformer.withNewChildren(Seq(newChild)) + case other => other + } + + private def outputOrderSatisfied(child: SparkPlan, required: Seq[Seq[SortOrder]]): Boolean = { + Seq(child).zip(required).forall { + case (child, requiredOrdering) => + SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering) + } + } +} + // Other transformations. case class OffloadOthers() extends OffloadSingleNode with LogLevelUtil { import OffloadOthers._