Skip to content

Commit

Permalink
[CORE] Consider the cost when applying stage fallback policy (#3569)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored Nov 2, 2023
1 parent 697b2bd commit 81fae1e
Showing 1 changed file with 97 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ package io.glutenproject.extension

import io.glutenproject.GlutenConfig
import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.execution.{BasicScanExecTransformer, BroadcastHashJoinExecTransformer}
import io.glutenproject.execution.BroadcastHashJoinExecTransformer
import io.glutenproject.extension.columnar.TransformHints

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, ColumnarToRowExec, CommandResultExec, LeafExecNode, RowToColumnarExec, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, ColumnarToRowExec, CommandResultExec, LeafExecNode, SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AQEShuffleReadExec, BroadcastQueryStageExec, ColumnarAQEShuffleReadExec, QueryStageExec, ShuffleQueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}

Expand Down Expand Up @@ -67,78 +68,115 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkPlan)
extends Rule[SparkPlan] {

/**
 * Decides whether the only ColumnarToRow transition in `plan` should be disregarded when
 * counting fallbacks.
 *
 * @param plan
 *   the plan to inspect
 * @return
 *   true when no leaf is a `BasicScanExecTransformer`, the plan holds exactly one
 *   `ColumnarToRowExec`, and it holds no `RowToColumnarExec`
 */
private def ignoreOneColumnarToRow(plan: SparkPlan): Boolean = {
  // A native file scan may be replaced entirely by the vanilla Spark file scan, whereas a
  // materialized `ColumnarShuffleExchangeExec` always needs a `ColumnarToRowExec` once we
  // decide to fall back. Hence a plan with a native scan leaf is never ignored here.
  val containsNativeScan =
    plan.collectLeaves().exists(_.isInstanceOf[BasicScanExecTransformer])

  // spotless:off
  // Adding ColumnarToRow eagerly is pointless when one is going to be added in the end
  // anyway, so that single transition is ignored to avoid an eager fallback.
  // For example: 1 is better than 2
  //
  // 1. Run the native operator first, then add ColumnarToRow
  //   ColumnarExchange
  //          |
  //   HashAggregateTransformer
  //          |
  //   ColumnarToRow
  //
  // 2. Add ColumnarToRow first, then run the vanilla Spark operator
  //   ColumnarExchange
  //          |
  //   ColumnarToRow
  //          |
  //   HashAggregate
  //
  // spotless:on
  val columnarToRowCount = plan.collect { case c2r: ColumnarToRowExec => c2r }.size
  val rowToColumnarCount = plan.collect { case r2c: RowToColumnarExec => r2c }.size

  !containsNativeScan && columnarToRowCount == 1 && rowToColumnarCount == 0
}

// Counts the fallback points found while walking `plan`: a nested helper pattern-matches each
// node, bumps the `fallbacks` counter where a match arm says so, and the counter is returned.
// NOTE(review): this span comes from a rendered diff without +/- markers, so two variants of
// several lines appear back to back (both `def` headers below, the paired recursive calls, and
// the two ways the walk is started at the end). Presumably `countFallbacks` / `countFallback`
// (inner) is the removed variant and `countFallback` / `countFallbackInternal` the added one -
// confirm against the real file before treating this text as compilable code.
private def countFallbacks(plan: SparkPlan): Int = {
private def countFallback(plan: SparkPlan): Int = {
var fallbacks = 0
// Recursive visitor; it returns Unit and reports its findings only through `fallbacks`.
def countFallback(plan: SparkPlan): Unit = {
def countFallbackInternal(plan: SparkPlan): Unit = {
plan match {
case _: QueryStageExec => // Another stage.
case _: CommandResultExec | _: ExecutedCommandExec => // ignore
// we plan exchange to columnar exchange in columnar rules and the exchange does not
// support columnar, so the output columnar is always false in AQE postStageCreationRules
case ColumnarToRowExec(s: Exchange) if isAdaptiveContext =>
// Not counted as a fallback; the walk simply continues below the exchange.
countFallback(s)
countFallbackInternal(s)
case u: UnaryExecNode
if !u.isInstanceOf[GlutenPlan] && InMemoryTableScanHelper.isGlutenTableCache(u.child) =>
// Vanilla Spark plan will call `InMemoryTableScanExec.convertCachedBatchToInternalRow`
// which is a kind of `ColumnarToRowExec`.
fallbacks = fallbacks + 1
countFallback(u.child)
countFallbackInternal(u.child)
case ColumnarToRowExec(p: GlutenPlan) =>
// A ColumnarToRow directly above a Gluten plan counts as one fallback.
logDebug(s"Find a columnar to row for gluten plan:\n$p")
fallbacks = fallbacks + 1
countFallback(p)
countFallbackInternal(p)
case leafPlan: LeafExecNode if InMemoryTableScanHelper.isGlutenTableCache(leafPlan) =>
case leafPlan: LeafExecNode if !leafPlan.isInstanceOf[GlutenPlan] =>
// Possible fallback for leaf node.
fallbacks = fallbacks + 1
// Any other node contributes nothing itself; only its children are visited.
case p => p.children.foreach(countFallback)
case p => p.children.foreach(countFallbackInternal)
}
}
// One variant starts the walk only when `ignoreOneColumnarToRow(plan)` is false; the other
// starts it unconditionally.
if (!ignoreOneColumnarToRow(plan)) {
countFallback(plan)
}
countFallbackInternal(plan)
fallbacks
}

/**
 * Estimates the cost of falling a whole stage back, expressed as the net number of
 * ColumnarToRow transitions that would be required between the fallen-back stage and the
 * columnar output it consumes (the last stage(s) or a table cache). Taking this cost into
 * account helps avoid a performance degradation caused by the fallback policy itself.
 *
 * spotless:off
 *
 * Spark plan before applying fallback policy:
 *
 *        ColumnarExchange
 *  ----------- | --------------- last stage
 *    HashAggregateTransformer
 *              |
 *        ColumnarToRow
 *              |
 *           Project
 *
 * What the whole stage fallback policy (threshold = 1) would produce if the cost were
 * not considered:
 *
 *        ColumnarExchange
 *  ----------- | --------------- last stage
 *        ColumnarToRow
 *              |
 *         HashAggregate
 *              |
 *           Project
 *
 * With the cost considered, the fallback policy is not applied to such a plan.
 *
 * spotless:on
 *
 * @param plan
 *   the stage plan to evaluate
 * @return
 *   the accumulated cost; every columnar input adds one and every row-based input subtracts one
 */
private def countStageFallbackCost(plan: SparkPlan): Int = {
  var cost = 0

  def isCachedScan(p: SparkPlan): Boolean = p.isInstanceOf[InMemoryTableScanExec]
  def isQueryStage(p: SparkPlan): Boolean = p.isInstanceOf[QueryStageExec]

  /**
   * Walks the plan top-down and stops descending at the first Gluten plan that has either
   * 1) an InMemoryTableScanExec child: the cost goes up when that scan is gluten's table
   * cache and down otherwise; or 2) a QueryStageExec child: the cost goes up when the last
   * stage's plan is a GlutenPlan (or a gluten table cache) and down otherwise.
   */
  def visit(node: SparkPlan): Unit = node match {
    case _: GlutenPlan if node.children.exists(isCachedScan) =>
      node.children.filter(isCachedScan).foreach {
        child =>
          if (InMemoryTableScanHelper.isGlutenTableCache(child)) {
            // The table cache will internally execute ColumnarToRow if the stage falls back.
            cost += 1
          } else {
            // The table cache's internal RowToColumnar is saved if the stage falls back.
            cost -= 1
          }
      }
    case _: GlutenPlan if node.children.exists(isQueryStage) =>
      node.children.foreach {
        case stage: QueryStageExec =>
          val producesColumnar =
            stage.plan.isInstanceOf[GlutenPlan] ||
              // For TableCacheQueryStageExec since spark 3.5.
              InMemoryTableScanHelper.isGlutenTableCache(stage)
          if (producesColumnar) {
            cost += 1
          } else {
            // RowToColumnar will be removed if the stage falls back, so the cost goes down.
            cost -= 1
          }
        case _ => // Not a query stage: contributes nothing.
      }
    case other => other.children.foreach(visit)
  }

  visit(plan)
  cost
}

private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
Expand All @@ -159,7 +197,7 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
GlutenConfig.getConf.wholeStageFallbackThreshold
} else if (plan.find(_.isInstanceOf[AdaptiveSparkPlanExec]).isDefined) {
// if we are here, that means we are now at `QueryExecution.preparations` and
// AQE is actually applied. We do nothing for this case, and later in
// AQE is actually not applied. We do nothing for this case, and later in
// AQE we can check `wholeStageFallbackThreshold`.
return None
} else {
Expand All @@ -175,11 +213,15 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
return None
}

val numFallback = countFallbacks(plan)
if (numFallback >= fallbackThreshold) {
val netFallbackNum = if (isAdaptiveContext) {
countFallback(plan) - countStageFallbackCost(plan)
} else {
countFallback(plan)
}
if (netFallbackNum >= fallbackThreshold) {
Some(
s"Fall back the plan due to fallback number $numFallback, " +
s"threshold $fallbackThreshold")
s"Fallback policy is taking effect, net fallback number: $netFallbackNum, " +
s"threshold: $fallbackThreshold")
} else {
None
}
Expand Down

0 comments on commit 81fae1e

Please sign in to comment.