Skip to content

Commit

Permalink
[VL] Remove a limit for BHJ in stage fallback policy (apache#7105)
Browse files Browse the repository at this point in the history
  • Loading branch information
PHILO-HE authored and shamirchen committed Oct 14, 2024
1 parent cfbafeb commit 0d34941
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.GlutenPlan

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}

class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
protected val rootPath: String = getClass.getResource("/").getPath
Expand Down Expand Up @@ -106,35 +106,36 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
}
}

test("fallback with bhj") {
withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "2") {
test("offload BroadcastExchange and fall back BHJ") {
withSQLConf(
"spark.gluten.sql.columnar.broadcastJoin" -> "false"
) {
runQueryAndCompare(
"""
|SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1
|LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1
|SELECT java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1
|LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 limit 10
|""".stripMargin
) {
df =>
val plan = df.queryExecution.executedPlan
val bhj = find(plan) {
val columnarBhj = find(plan) {
case _: BroadcastHashJoinExecTransformerBase => true
case _ => false
}
assert(bhj.isDefined)
val columnarToRow = collectColumnarToRow(bhj.get)
assert(columnarToRow == 0)
assert(!columnarBhj.isDefined)

val wholeQueryColumnarToRow = collectColumnarToRow(plan)
assert(wholeQueryColumnarToRow == 1)
}
val vanillaBhj = find(plan) {
case _: BroadcastHashJoinExec => true
case _ => false
}
assert(vanillaBhj.isDefined)

// before the fix, it would fail
spark
.sql("""
|SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1
|LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1
|""".stripMargin)
.show()
val columnarBroadcastExchange = find(plan) {
case _: ColumnarBroadcastExchangeExec => true
case _ => false
}
assert(columnarBroadcastExchange.isDefined)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.BroadcastHashJoinExecTransformerBase
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike, Transitions}
import org.apache.gluten.utils.PlanUtil
Expand All @@ -27,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.exchange.Exchange
Expand Down Expand Up @@ -179,21 +178,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
stageFallbackTransitionCost
}

/**
 * Checks whether `plan` contains a [[BroadcastHashJoinExecTransformerBase]] that has, as its
 * left or right child, a [[BroadcastQueryStageExec]] wrapping a
 * [[ColumnarBroadcastExchangeExec]].
 *
 * @param plan the physical plan to search
 * @return true if at least one such join node is found, false otherwise
 */
private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
  // Only a broadcast query stage whose wrapped plan is a columnar broadcast exchange qualifies.
  val wrapsColumnarBroadcast: SparkPlan => Boolean = {
    case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
    case _ => false
  }

  val matchingJoin = plan.find {
    case join: BroadcastHashJoinExecTransformerBase =>
      // Either side of the join may carry the broadcast stage.
      wrapsColumnarBroadcast(join.left) || wrapsColumnarBroadcast(join.right)
    case _ => false
  }
  matchingJoin.isDefined
}

private def fallback(plan: SparkPlan): FallbackInfo = {
val fallbackThreshold = if (isAdaptiveContext) {
GlutenConfig.getConf.wholeStageFallbackThreshold
Expand All @@ -210,11 +194,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
return FallbackInfo.DO_NOT_FALLBACK()
}

// Not safe to fall back to a row-based BHJ, as the broadcast exchange is already columnar.
if (hasColumnarBroadcastExchangeWithJoin(plan)) {
return FallbackInfo.DO_NOT_FALLBACK()
}

val transitionCost = countTransitionCost(plan)
val fallbackTransitionCost = if (isAdaptiveContext) {
countStageFallbackTransitionCost(plan)
Expand Down

0 comments on commit 0d34941

Please sign in to comment.