Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VL] Remove a limit for BHJ in stage fallback policy #7105

Merged
merged 4 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.GlutenPlan

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, SparkPlan}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec}
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}

class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper {
protected val rootPath: String = getClass.getResource("/").getPath
Expand Down Expand Up @@ -106,35 +106,36 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl
}
}

// Regression test for the stage fallback policy change: with columnar broadcast
// join explicitly disabled, the BroadcastExchange itself should still be offloaded
// (stay columnar) while only the BroadcastHashJoin falls back to vanilla Spark.
// Before this fix, the policy refused to fall back any stage containing a columnar
// broadcast exchange feeding a join, which over-constrained fallback decisions.
test("offload BroadcastExchange and fall back BHJ") {
  withSQLConf(
    "spark.gluten.sql.columnar.broadcastJoin" -> "false"
  ) {
    runQueryAndCompare(
      """
        |SELECT java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1
        |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 limit 10
        |""".stripMargin
    ) {
      df =>
        val plan = df.queryExecution.executedPlan

        // The Gluten (columnar) BHJ transformer must NOT appear: we disabled it.
        val columnarBhj = find(plan) {
          case _: BroadcastHashJoinExecTransformerBase => true
          case _ => false
        }
        assert(columnarBhj.isEmpty)

        // The join runs as vanilla Spark's row-based BroadcastHashJoinExec instead.
        val vanillaBhj = find(plan) {
          case _: BroadcastHashJoinExec => true
          case _ => false
        }
        assert(vanillaBhj.isDefined)

        // The broadcast exchange feeding the fallen-back join is still columnar,
        // proving the exchange was offloaded independently of the join.
        val columnarBroadcastExchange = find(plan) {
          case _: ColumnarBroadcastExchangeExec => true
          case _ => false
        }
        assert(columnarBroadcastExchange.isDefined)
    }
  }
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.apache.gluten.extension.columnar

import org.apache.gluten.GlutenConfig
import org.apache.gluten.execution.BroadcastHashJoinExecTransformerBase
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike, Transitions}
import org.apache.gluten.utils.PlanUtil
Expand All @@ -27,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec}
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.execution.command.ExecutedCommandExec
import org.apache.spark.sql.execution.exchange.Exchange
Expand Down Expand Up @@ -179,21 +178,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
stageFallbackTransitionCost
}

/**
 * Returns true if `plan` contains a Gluten broadcast-hash-join transformer whose
 * left or right child is a broadcast query stage wrapping a columnar
 * BroadcastExchange.
 *
 * NOTE(review): this helper is REMOVED by this PR (#7105). It previously forced
 * DO_NOT_FALLBACK for any stage containing a columnar broadcast exchange feeding
 * a BHJ, on the assumption that a row-based BHJ could not consume an
 * already-columnar broadcast exchange. The PR lifts that limit so only the join
 * falls back while the exchange stays offloaded.
 */
private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = {
// Local predicate: is `p` a broadcast query stage wrapping a columnar exchange?
def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match {
case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true
case _ => false
}

// Search the whole plan tree for a columnar BHJ transformer fed (on either
// side) by a columnar broadcast exchange stage.
plan.find {
case j: BroadcastHashJoinExecTransformerBase
if isColumnarBroadcastExchange(j.left) ||
isColumnarBroadcastExchange(j.right) =>
true
case _ => false
}.isDefined
}

private def fallback(plan: SparkPlan): FallbackInfo = {
val fallbackThreshold = if (isAdaptiveContext) {
GlutenConfig.getConf.wholeStageFallbackThreshold
Expand All @@ -210,11 +194,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP
return FallbackInfo.DO_NOT_FALLBACK()
}

// not safe to fallback row-based BHJ as the broadcast exchange is already columnar
if (hasColumnarBroadcastExchangeWithJoin(plan)) {
return FallbackInfo.DO_NOT_FALLBACK()
}

val transitionCost = countTransitionCost(plan)
val fallbackTransitionCost = if (isAdaptiveContext) {
countStageFallbackTransitionCost(plan)
Expand Down
Loading