[EMR] Support falling back the whole stage according to the number of window sorts in the stage
liujiayi771 authored and kerwin-zk committed Nov 14, 2024
1 parent c989524 commit 4610784
Showing 4 changed files with 172 additions and 1 deletion.
@@ -76,6 +76,8 @@ private object VeloxRuleApi {
    // Gluten columnar: Fallback policies.
    injector.injectFallbackPolicy(
      c => ExpandFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))
    injector.injectFallbackPolicy(
      c => WindowSortFallbackPolicy(c.ac.isAdaptiveContext(), c.ac.originalPlan()))

    // Gluten columnar: Post rules.
    injector.injectPost(c => RemoveTopmostColumnarToRow(c.session, c.ac.isAdaptiveContext()))
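A fallback policy registered through `injectFallbackPolicy` is a `Rule[SparkPlan]`: it inspects the columnar plan and either returns it unchanged or substitutes a row-based replacement. A minimal sketch of that contract, where `NoOpFallbackPolicy` is an illustrative name rather than part of this patch:

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.SparkPlan

// Illustrative only: a policy that never triggers a fallback. A real policy,
// such as the WindowSortFallbackPolicy added below, returns a replacement
// plan when its fallback condition is met.
case class NoOpFallbackPolicy() extends Rule[SparkPlan] {
  override def apply(plan: SparkPlan): SparkPlan = plan
}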
@@ -40,6 +40,7 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl

  override def beforeAll(): Unit = {
    super.beforeAll()
    createTPCHNotNullTables()

    spark
      .range(100)
@@ -65,7 +66,9 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
spark.sql("drop table tmp1")
spark.sql("drop table tmp2")
spark.sql("drop table tmp3")

if (TPCHTableDataFrames != null) {
TPCHTableDataFrames.keys.foreach(table => spark.sql(s"DROP TABLE IF EXISTS $table"))
}
super.afterAll()
}

@@ -271,4 +274,24 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
      }
    }
  }

test("test window sort fallback threshold") {
withSQLConf(GlutenConfig.COLUMNAR_WINDOWSORT_FALLBACK_THRESHOLD.key -> "2") {
runQueryAndCompare(
"""
|SELECT
| min(l_partkey) OVER
| (PARTITION BY l_orderkey ORDER BY l_suppkey
| ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) w1,
| max(l_linenumber) OVER
| (PARTITION BY l_orderkey ORDER BY l_receiptdate
| ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) w2
|FROM lineitem ORDER BY l_orderkey limit 10
|""".stripMargin) {
df =>
val plan = df.queryExecution.executedPlan
assert(collect(plan) { case f: WindowExecTransformer => f }.isEmpty)
}
}
}
}
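The assertion above checks that no `WindowExecTransformer` survives, which implies the whole stage fell back. To target exactly the window-plus-sort pairs the new policy counts, a test could also collect on the same pattern the policy matches. A hedged sketch, reusing `collect` from Spark's AdaptiveSparkPlanHelper and the extractor shape from the policy code below:

val windowSorts = collect(df.queryExecution.executedPlan) {
  case w @ WindowExecTransformer(_, _, _, _: SortExecTransformer) => w
}
// With the threshold set to 2 and two window sorts in the stage, the whole
// stage is expected to fall back, leaving no such pairs behind.
assert(windowSorts.isEmpty)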
@@ -0,0 +1,135 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.{BroadcastHashJoinExecTransformerBase, SortExecTransformer, WindowExecTransformer}
import org.apache.gluten.extension.columnar.heuristic.FallbackNode
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, Transitions}

import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.exchange.Exchange

/**
 * @param isAdaptiveContext
 *   Whether this rule is applied inside AQE.
 * @param originalPlan
 *   The vanilla SparkPlan before Gluten transform rules are applied.
 */
case class WindowSortFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan)
  extends Rule[SparkPlan] {

  private def getWindowSortCount(plan: SparkPlan): Int = {
    var windowSortCount = 0
    def countWindowSortInternal(plan: SparkPlan): Unit = {
      plan match {
        case _: QueryStageExec => // Another stage.
        case _: CommandResultExec | _: ExecutedCommandExec => // Ignore.
        // We plan exchanges to columnar exchanges in the columnar rules, and the exchange
        // itself does not support columnar, so `supportsColumnar` is always false in AQE
        // postStageCreationRules.
        case ColumnarToRowLike(s: Exchange) if isAdaptiveContext =>
          countWindowSortInternal(s)
        case WindowExecTransformer(_, _, _, s: SortExecTransformer) =>
          windowSortCount += 1
          countWindowSortInternal(s.child)
        case p => p.children.foreach(countWindowSortInternal)
      }
    }
    countWindowSortInternal(plan)
    windowSortCount
  }
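
  // For illustration (not part of the original patch): given a stage shaped
  //   WindowExecTransformer
  //     +- SortExecTransformer
  //          +- WindowExecTransformer
  //               +- SortExecTransformer
  //                    +- ...
  // getWindowSortCount returns 2: each window whose direct child is a
  // transformer sort adds one, and the traversal continues below that sort.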

  private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
    def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
      case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
      case _ => false
    }

    plan.find {
      case j: BroadcastHashJoinExecTransformerBase
          if isColumnarBroadcastExchange(j.left) ||
            isColumnarBroadcastExchange(j.right) =>
        true
      case _ => false
    }.isDefined
  }

  private def fallback(plan: SparkPlan): FallbackInfo = {
    val fallbackThreshold = if (isAdaptiveContext) {
      GlutenConfig.getConf.wholeStageWindowSortFallbackThreshold
    } else if (plan.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined) {
      // If we reach here, we are at `QueryExecution.preparations` and AQE has not
      // actually been applied. Do nothing for this case; later, inside AQE, we can
      // check `wholeStageWindowSortFallbackThreshold`.
      return FallbackInfo.DO_NOT_FALLBACK()
    } else {
      // AQE is not applied, so we use the whole-query threshold to decide whether to fall back.
      GlutenConfig.getConf.queryFallbackThreshold
    }
    if (fallbackThreshold < 0) {
      return FallbackInfo.DO_NOT_FALLBACK()
    }

    // It is not safe to fall back to a row-based BHJ because the broadcast exchange is
    // already columnar.
    if (hasColumnarBroadcastExchangeWithJoin(plan)) {
      return FallbackInfo.DO_NOT_FALLBACK()
    }

    val windowSortCount = getWindowSortCount(plan)

    if (windowSortCount >= fallbackThreshold) {
      FallbackInfo(
        Some(
          s"Window sort fallback policy is taking effect, window sort count: $windowSortCount, " +
            s"threshold: $fallbackThreshold"),
        windowSortCount
      )
    } else {
      FallbackInfo(windowSortCount = windowSortCount)
    }
  }

  private def fallbackToRowBasedPlan(outputsColumnar: Boolean): SparkPlan =
    Transitions.insertTransitions(originalPlan, outputsColumnar)

  override def apply(plan: SparkPlan): SparkPlan = {
    val outputsColumnar = plan.supportsColumnar
    val fallbackInfo = fallback(plan)
    if (fallbackInfo.shouldFallback) {
      val vanillaSparkPlan = fallbackToRowBasedPlan(outputsColumnar)
      FallbackTags.addRecursively(
        vanillaSparkPlan,
        FallbackTag.Exclusive(fallbackInfo.reason.getOrElse("Unknown reason")))
      FallbackNode(vanillaSparkPlan)
    } else {
      plan
    }
  }

  case class FallbackInfo(reason: Option[String] = None, windowSortCount: Int = 0) {
    def shouldFallback: Boolean = reason.isDefined
  }

  object FallbackInfo {
    def DO_NOT_FALLBACK(): FallbackInfo = FallbackInfo()
  }
}
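Stripped of the AQE staging and columnar broadcast-join guards above, the decision the policy makes once the count is known is a single comparison. A minimal standalone sketch (not the Gluten implementation itself):

// Fall back the whole stage only when the threshold is non-negative and the
// number of window sorts in the stage reaches it.
def shouldFallbackWholeStage(windowSortCount: Int, threshold: Int): Boolean =
  threshold >= 0 && windowSortCount >= threshold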
11 changes: 11 additions & 0 deletions shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -222,6 +222,9 @@ class GlutenConfig(conf: SQLConf) extends Logging {

  def wholeStageFallbackThreshold: Int = conf.getConf(COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD)

  def wholeStageWindowSortFallbackThreshold: Int =
    conf.getConf(COLUMNAR_WINDOWSORT_FALLBACK_THRESHOLD)

  def queryFallbackThreshold: Int = conf.getConf(COLUMNAR_QUERY_FALLBACK_THRESHOLD)

  def fallbackIgnoreRowToColumnar: Boolean = conf.getConf(COLUMNAR_FALLBACK_IGNORE_ROW_TO_COLUMNAR)
@@ -1087,6 +1090,14 @@ object GlutenConfig {
      .intConf
      .createWithDefault(12)

  val COLUMNAR_WINDOWSORT_FALLBACK_THRESHOLD =
    buildConf("spark.gluten.sql.columnar.windowsort.fallback.threshold")
      .internal()
      .doc("The threshold for whether the whole stage will fall back in AQE-supported cases, " +
        "determined by counting the number of window-with-sort nodes in the stage.")
      .intConf
      .createWithDefault(5)

  val COLUMNAR_SCAN_ONLY_ENABLED =
    buildConf("spark.gluten.sql.columnar.scanOnly")
      .internal()
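The new threshold can be tuned like any other Gluten SQL conf; per the `fallbackThreshold < 0` check in the policy, a negative value disables the window-sort fallback entirely. A hypothetical session-level example:

// Fall back the whole stage once it contains two or more window sorts.
spark.conf.set("spark.gluten.sql.columnar.windowsort.fallback.threshold", "2")

// Or disable the window-sort fallback policy altogether.
spark.conf.set("spark.gluten.sql.columnar.windowsort.fallback.threshold", "-1")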