Revert "Remove local sort for TopNRowNumber (#6381)"
This reverts commit 0448115.
yma11 committed Jul 15, 2024
1 parent 1382f4c commit 719956a
Showing 15 changed files with 87 additions and 142 deletions.
@@ -485,11 +485,7 @@ object VeloxBackendSettings extends BackendSettingsApi {

override def alwaysFailOnMapExpression(): Boolean = true

- override def requiredChildOrderingForWindow(): Boolean = {
- GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")
- }
-
- override def requiredChildOrderingForWindowGroupLimit(): Boolean = false
+ override def requiredChildOrderingForWindow(): Boolean = true

override def staticPartitionWriteOnly(): Boolean = true

@@ -123,9 +123,9 @@ trait BackendSettingsApi {

def alwaysFailOnMapExpression(): Boolean = false

- def requiredChildOrderingForWindow(): Boolean = true
+ def requiredChildOrderingForWindow(): Boolean = false

- def requiredChildOrderingForWindowGroupLimit(): Boolean = true
+ def requiredChildOrderingForWindowGroupLimit(): Boolean = false

def staticPartitionWriteOnly(): Boolean = false

@@ -185,7 +185,8 @@ object HashAggregateExecBaseTransformer {
case a: SortAggregateExec => a.initialInputBufferOffset
}

- def from(agg: BaseAggregateExec): HashAggregateExecBaseTransformer = {
+ def from(agg: BaseAggregateExec)(
+ childConverter: SparkPlan => SparkPlan = p => p): HashAggregateExecBaseTransformer = {
BackendsApiManager.getSparkPlanExecApiInstance
.genHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
@@ -194,7 +195,7 @@ agg.aggregateAttributes,
agg.aggregateAttributes,
getInitialInputBufferOffset(agg),
agg.resultExpressions,
- agg.child
+ childConverter(agg.child)
)
}
}
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.extension.columnar.rewrite.RewrittenNodeWall

import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan}

object SortUtils {
def dropPartialSort(plan: SparkPlan): SparkPlan = plan match {
case RewrittenNodeWall(p) => RewrittenNodeWall(dropPartialSort(p))
case PartialSortLike(child) => child
// from pre/post project-pulling
case ProjectLike(PartialSortLike(ProjectLike(child))) if plan.outputSet == child.outputSet =>
child
case ProjectLike(PartialSortLike(child)) => plan.withNewChildren(Seq(child))
case _ => plan
}
}

object PartialSortLike {
def unapply(plan: SparkPlan): Option[SparkPlan] = plan match {
case sort: SortExecTransformer if !sort.global => Some(sort.child)
case sort: SortExec if !sort.global => Some(sort.child)
case _ => None
}
}

object ProjectLike {
def unapply(plan: SparkPlan): Option[SparkPlan] = plan match {
case project: ProjectExecTransformer => Some(project.child)
case project: ProjectExec => Some(project.child)
case _ => None
}
}
@@ -67,7 +67,11 @@ case class WindowExecTransformer(
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
- if (BackendsApiManager.getSettings.requiredChildOrderingForWindow()) {
+ if (
+ BackendsApiManager.getSettings.requiredChildOrderingForWindow()
+ && GlutenConfig.getConf.veloxColumnarWindowType.equals("streaming")
+ ) {
+ // Velox StreamingWindow need to require child order.
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
} else {
Seq(Nil)
@@ -66,24 +66,14 @@ case class WindowGroupLimitExecTransformer(

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
+ // Velox StreamingTopNRowNumber need to require child order.
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
} else {
Seq(Nil)
}
}

- override def outputOrdering: Seq[SortOrder] = {
- if (requiredChildOrdering.forall(_.isEmpty)) {
- // The Velox backend `TopNRowNumber` does not require child ordering, because it
- // uses hash table to store partition and use priority queue to track of top limit rows.
- // Ideally, the output of `TopNRowNumber` is unordered but it is grouped for partition keys.
- // To be safe, here we do not propagate the ordering.
- // TODO: Make the framework aware of grouped data distribution
- Nil
- } else {
- child.outputOrdering
- }
- }
+ override def outputOrdering: Seq[SortOrder] = child.outputOrdering

override def outputPartitioning: Partitioning = child.outputPartitioning


This file was deleted.

@@ -17,8 +17,7 @@
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
- import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
- import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager
+ import org.apache.gluten.execution.SortExecTransformer

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
@@ -33,8 +32,6 @@ import org.apache.spark.sql.execution.{SortExec, SparkPlan}
* SortAggregate with the same key. So, this rule adds local sort back if necessary.
*/
object EnsureLocalSortRequirements extends Rule[SparkPlan] {
- private lazy val offload = TransformPreOverrides.apply()
-
private def addLocalSort(
originalChild: SparkPlan,
requiredOrdering: Seq[SortOrder]): SparkPlan = {
@@ -43,12 +40,18 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] {
FallbackTags.add(newChild, "columnar Sort is not enabled in SortExec")
newChild
} else {
- val rewrittenPlan = RewriteSparkPlanRulesManager.apply().apply(newChild)
- if (rewrittenPlan.eq(newChild) && FallbackTags.nonEmpty(rewrittenPlan)) {
- // The sort can not be offloaded
- rewrittenPlan
+ val newChildWithTransformer =
+ SortExecTransformer(
+ newChild.sortOrder,
+ newChild.global,
+ newChild.child,
+ newChild.testSpillFrequency)
+ val validationResult = newChildWithTransformer.doValidate()
+ if (validationResult.isValid) {
+ newChildWithTransformer
} else {
- offload.apply(rewrittenPlan)
+ FallbackTags.add(newChild, validationResult)
+ newChild
}
}
}
@@ -344,13 +344,13 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
.genFilterExecTransformer(plan.condition, plan.child)
transformer.doValidate().tagOnFallback(plan)
case plan: HashAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: SortAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: ObjectHashAggregateExec =>
- val transformer = HashAggregateExecBaseTransformer.from(plan)
+ val transformer = HashAggregateExecBaseTransformer.from(plan)()
transformer.doValidate().tagOnFallback(plan)
case plan: UnionExec =>
val transformer = ColumnarUnionExec(plan.children)
@@ -87,17 +87,17 @@ case class OffloadAggregate() extends OffloadSingleNode with LogLevelUtil {
case _: TransformSupport =>
// If the child is transformable, transform aggregation as well.
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- HashAggregateExecBaseTransformer.from(plan)
+ HashAggregateExecBaseTransformer.from(plan)()
case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
- HashAggregateExecBaseTransformer.from(plan)
+ HashAggregateExecBaseTransformer.from(plan)()
case _ =>
// If the child is not transformable, do not transform the agg.
FallbackTags.add(plan, "child output schema is empty")
plan
}
} else {
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- HashAggregateExecBaseTransformer.from(plan)
+ HashAggregateExecBaseTransformer.from(plan)()
}
}
}
@@ -425,10 +425,10 @@ object OffloadOthers {
ColumnarCoalesceExec(plan.numPartitions, plan.child)
case plan: SortAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- HashAggregateExecBaseTransformer.from(plan)
+ HashAggregateExecBaseTransformer.from(plan)(SortUtils.dropPartialSort)
case plan: ObjectHashAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
- HashAggregateExecBaseTransformer.from(plan)
+ HashAggregateExecBaseTransformer.from(plan)()
case plan: UnionExec =>
val children = plan.children
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
@@ -102,7 +102,6 @@ class EnumeratedApplier(session: SparkSession)
List(
(_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
(spark: SparkSession) => RewriteTransformer(spark),
- (_: SparkSession) => EliminateLocalSort,
(_: SparkSession) => EnsureLocalSortRequirements,
(_: SparkSession) => CollapseProjectExecTransformer
) :::
@@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
object RasOffloadHashAggregate extends RasOffload {
override def offload(node: SparkPlan): SparkPlan = node match {
case agg: HashAggregateExec =>
- val out = HashAggregateExecBaseTransformer.from(agg)
+ val out = HashAggregateExecBaseTransformer.from(agg)()
out
case other => other
}
@@ -114,7 +114,6 @@ class HeuristicApplier(session: SparkSession)
List(
(_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
(spark: SparkSession) => RewriteTransformer(spark),
- (_: SparkSession) => EliminateLocalSort,
(_: SparkSession) => EnsureLocalSortRequirements,
(_: SparkSession) => CollapseProjectExecTransformer
) :::
@@ -17,6 +17,7 @@
package org.apache.gluten.extension.columnar.rewrite

import org.apache.gluten.GlutenConfig
+ import org.apache.gluten.execution.SortUtils

import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -51,8 +52,8 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
smj.joinType,
buildSide,
smj.condition,
- smj.left,
- smj.right,
+ SortUtils.dropPartialSort(smj.left),
+ SortUtils.dropPartialSort(smj.right),
smj.isSkewJoin
)
case _ => plan
@@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution

- import org.apache.gluten.execution.{SortExecTransformer, WindowExecTransformer, WindowGroupLimitExecTransformer}
+ import org.apache.gluten.execution.{WindowExecTransformer, WindowGroupLimitExecTransformer}

import org.apache.spark.sql.GlutenSQLTestsTrait
import org.apache.spark.sql.Row
@@ -134,9 +134,6 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL
case _ => false
}
)
- assert(
- getExecutedPlan(df).collect { case s: SortExecTransformer if !s.global => s }.size == 1
- )
}
}
