Skip to content

Commit

Permalink
Remove local sort for TopNRowNumber
Browse files Browse the repository at this point in the history
  • Loading branch information
ulysses-you committed Jul 10, 2024
1 parent be57db8 commit 58a800c
Show file tree
Hide file tree
Showing 13 changed files with 106 additions and 90 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,6 @@ trait BackendSettingsApi {

def requiredChildOrderingForWindow(): Boolean = false

def requiredChildOrderingForWindowGroupLimit(): Boolean = false

def staticPartitionWriteOnly(): Boolean = false

def supportTransformWriteFiles: Boolean = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,7 @@ object HashAggregateExecBaseTransformer {
case a: SortAggregateExec => a.initialInputBufferOffset
}

def from(agg: BaseAggregateExec)(
childConverter: SparkPlan => SparkPlan = p => p): HashAggregateExecBaseTransformer = {
def from(agg: BaseAggregateExec): HashAggregateExecBaseTransformer = {
BackendsApiManager.getSparkPlanExecApiInstance
.genHashAggregateExecTransformer(
agg.requiredChildDistributionExpressions,
Expand All @@ -195,7 +194,7 @@ object HashAggregateExecBaseTransformer {
agg.aggregateAttributes,
getInitialInputBufferOffset(agg),
agg.resultExpressions,
childConverter(agg.child)
agg.child
)
}
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.extensions.ExtensionBuilder
import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}

import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, SortOrder}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.window.{Final, Partial, WindowGroupLimitMode}
Expand Down Expand Up @@ -64,16 +64,12 @@ case class WindowGroupLimitExecTransformer(
}
}

override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
if (BackendsApiManager.getSettings.requiredChildOrderingForWindowGroupLimit()) {
// Velox StreamingTopNRowNumber need to require child order.
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)
} else {
Seq(Nil)
}
}

override def outputOrdering: Seq[SortOrder] = child.outputOrdering
// The Velox backend `TopNRowNumber` does not require child ordering, because it uses a
// hash table to store partitions and a priority queue to keep track of the top limit rows.
// Ideally, the output of `TopNRowNumber` is unordered but grouped by the partition keys.
// To be safe, we do not propagate the ordering here.
// TODO: Make the framework aware of grouped data distribution.
override def outputOrdering: Seq[SortOrder] = Nil

override def outputPartitioning: Partitioning = child.outputPartitioning

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar

import org.apache.gluten.execution.{ProjectExecTransformer, SortExecTransformer}

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan}

/**
 * Removes local sorts that their parent operator no longer needs.
 *
 * Such redundant sorts can be left behind when:
 *   - a sort merge join is converted to a shuffled hash join
 *   - a SortAggregate is offloaded to the native hash aggregate
 *   - a WindowGroupLimit is offloaded to the native TopNRowNumber
 *
 * For every node, each child recognised by [[SortWithChild]] is replaced by the plan beneath the
 * sort when that plan's output ordering already satisfies the ordering the node requires from
 * that child; otherwise the child is kept unchanged.
 *
 * NOTE(review): presumably a child with no required ordering always counts as satisfied, so any
 * local sort below such a parent is dropped - confirm this is intended for local sorts that were
 * requested explicitly by the user.
 */
object EliminateLocalSort extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan.transformDown {
    case parent =>
      val orderingPerChild = parent.requiredChildOrdering
      // One ordering requirement is expected for each child of the node.
      assert(orderingPerChild.size == parent.children.size)
      val rewrittenChildren = parent.children.zipWithIndex.map {
        case (child, idx) =>
          child match {
            case SortWithChild(sortSubtree, belowSort) =>
              // Drop the sort only if the plan beneath it already delivers the required order.
              val alreadyOrdered =
                SortOrder.orderingSatisfies(belowSort.outputOrdering, orderingPerChild(idx))
              if (alreadyOrdered) belowSort else sortSubtree
            case untouched => untouched
          }
      }
      parent.withNewChildren(rewrittenChildren)
  }
}

/**
 * Extractor that recognises a local (non-global) sort and exposes the plan beneath it.
 *
 * On a match it yields a pair `(sortSubtree, below)`:
 *   - `sortSubtree` is the matched plan itself, i.e. what to keep if the sort is still needed;
 *   - `below` is the plan that would remain if the sort were removed.
 *
 * Four shapes are recognised:
 *   - `ProjectExec` over a local `SortExecTransformer` over a `ProjectExec`, when the outer
 *     project outputs exactly the attributes of the inner project's child; `below` is then the
 *     inner project's child. (Presumably this is the shape produced when sort expressions are
 *     pulled out into helper projections - verify against the rewrite rules.)
 *   - the same shape built from `ProjectExecTransformer` nodes;
 *   - a bare local `SortExec`;
 *   - a bare local `SortExecTransformer`.
 */
object SortWithChild {
  def unapply(plan: SparkPlan): Option[(SparkPlan, SparkPlan)] =
    Some(plan).collect[(SparkPlan, SparkPlan)] {
      case outer @ ProjectExec(_, SortExecTransformer(_, false, inner: ProjectExec, _))
          if outer.outputSet == inner.child.outputSet =>
        (outer, inner.child)
      case outer @ ProjectExecTransformer(
            _,
            SortExecTransformer(_, false, inner: ProjectExecTransformer, _))
          if outer.outputSet == inner.child.outputSet =>
        (outer, inner.child)
      case vanillaSort @ SortExec(_, false, below, _) =>
        (vanillaSort, below)
      case nativeSort @ SortExecTransformer(_, false, below, _) =>
        (nativeSort, below)
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.SortExecTransformer
import org.apache.gluten.extension.columnar.OffloadOthers.ReplaceSingleNode
import org.apache.gluten.extension.columnar.rewrite.RewriteSparkPlanRulesManager

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.rules.Rule
Expand All @@ -32,6 +33,8 @@ import org.apache.spark.sql.execution.{SortExec, SparkPlan}
* SortAggregate with the same key. So, this rule adds local sort back if necessary.
*/
object EnsureLocalSortRequirements extends Rule[SparkPlan] {
private lazy val offload = new ReplaceSingleNode()

private def addLocalSort(
originalChild: SparkPlan,
requiredOrdering: Seq[SortOrder]): SparkPlan = {
Expand All @@ -40,18 +43,12 @@ object EnsureLocalSortRequirements extends Rule[SparkPlan] {
FallbackTags.add(newChild, "columnar Sort is not enabled in SortExec")
newChild
} else {
val newChildWithTransformer =
SortExecTransformer(
newChild.sortOrder,
newChild.global,
newChild.child,
newChild.testSpillFrequency)
val validationResult = newChildWithTransformer.doValidate()
if (validationResult.isValid) {
newChildWithTransformer
val rewrittenPlan = RewriteSparkPlanRulesManager.apply().apply(newChild)
if (rewrittenPlan.eq(newChild) && FallbackTags.nonEmpty(rewrittenPlan)) {
// The sort can not be offloaded
rewrittenPlan
} else {
FallbackTags.add(newChild, validationResult)
newChild
rewrittenPlan.transform { case p => offload.doReplace(p) }
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,13 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
.genFilterExecTransformer(plan.condition, plan.child)
transformer.doValidate().tagOnFallback(plan)
case plan: HashAggregateExec =>
val transformer = HashAggregateExecBaseTransformer.from(plan)()
val transformer = HashAggregateExecBaseTransformer.from(plan)
transformer.doValidate().tagOnFallback(plan)
case plan: SortAggregateExec =>
val transformer = HashAggregateExecBaseTransformer.from(plan)()
val transformer = HashAggregateExecBaseTransformer.from(plan)
transformer.doValidate().tagOnFallback(plan)
case plan: ObjectHashAggregateExec =>
val transformer = HashAggregateExecBaseTransformer.from(plan)()
val transformer = HashAggregateExecBaseTransformer.from(plan)
transformer.doValidate().tagOnFallback(plan)
case plan: UnionExec =>
val transformer = ColumnarUnionExec(plan.children)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,17 @@ case class OffloadAggregate() extends OffloadSingleNode with LogLevelUtil {
case _: TransformSupport =>
// If the child is transformable, transform aggregation as well.
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)()
HashAggregateExecBaseTransformer.from(plan)
case p: SparkPlan if PlanUtil.isGlutenTableCache(p) =>
HashAggregateExecBaseTransformer.from(plan)()
HashAggregateExecBaseTransformer.from(plan)
case _ =>
// If the child is not transformable, do not transform the agg.
FallbackTags.add(plan, "child output schema is empty")
plan
}
} else {
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)()
HashAggregateExecBaseTransformer.from(plan)
}
}
}
Expand Down Expand Up @@ -425,10 +425,10 @@ object OffloadOthers {
ColumnarCoalesceExec(plan.numPartitions, plan.child)
case plan: SortAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)(SortUtils.dropPartialSort)
HashAggregateExecBaseTransformer.from(plan)
case plan: ObjectHashAggregateExec =>
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
HashAggregateExecBaseTransformer.from(plan)()
HashAggregateExecBaseTransformer.from(plan)
case plan: UnionExec =>
val children = plan.children
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class EnumeratedApplier(session: SparkSession)
List(
(_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
(spark: SparkSession) => RewriteTransformer(spark),
(_: SparkSession) => EliminateLocalSort,
(_: SparkSession) => EnsureLocalSortRequirements,
(_: SparkSession) => CollapseProjectExecTransformer
) :::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.aggregate.HashAggregateExec
object RasOffloadHashAggregate extends RasOffload {
override def offload(node: SparkPlan): SparkPlan = node match {
case agg: HashAggregateExec =>
val out = HashAggregateExecBaseTransformer.from(agg)()
val out = HashAggregateExecBaseTransformer.from(agg)
out
case other => other
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ class HeuristicApplier(session: SparkSession)
List(
(_: SparkSession) => RemoveNativeWriteFilesSortAndProject(),
(spark: SparkSession) => RewriteTransformer(spark),
(_: SparkSession) => EliminateLocalSort,
(_: SparkSession) => EnsureLocalSortRequirements,
(_: SparkSession) => CollapseProjectExecTransformer
) :::
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.gluten.extension.columnar.rewrite

import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.SortUtils

import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.JoinType
Expand Down Expand Up @@ -52,8 +51,8 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
smj.joinType,
buildSide,
smj.condition,
SortUtils.dropPartialSort(smj.left),
SortUtils.dropPartialSort(smj.right),
smj.left,
smj.right,
smj.isSkewJoin
)
case _ => plan
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/
package org.apache.spark.sql.execution

import org.apache.gluten.execution.{WindowExecTransformer, WindowGroupLimitExecTransformer}
import org.apache.gluten.execution.{SortExecTransformer, WindowExecTransformer, WindowGroupLimitExecTransformer}

import org.apache.spark.sql.GlutenSQLTestsTrait
import org.apache.spark.sql.Row
Expand Down Expand Up @@ -134,6 +134,9 @@ class GlutenSQLWindowFunctionSuite extends SQLWindowFunctionSuite with GlutenSQL
case _ => false
}
)
assert(
getExecutedPlan(df).collect { case s: SortExecTransformer => s }.size == 1
)
}
}

Expand Down

0 comments on commit 58a800c

Please sign in to comment.