// Patch summary (commentary only; everything from the first diff --git header onward is unchanged).
//
// Files touched:
//   1. backends-velox/.../execution/FallbackSuite.scala (test suite)
//   2. gluten-substrait/.../extension/columnar/ExpandFallbackPolicy.scala
//
// FallbackSuite.scala
//   - Imports gain ColumnarBroadcastExchangeExec and BroadcastHashJoinExec.
//   - The test named [fallback with bhj] is replaced by
//     [offload BroadcastExchange and fall back BHJ].
//   - The new test sets spark.gluten.sql.columnar.broadcastJoin to false instead of
//     setting GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD to 2.
//   - The query now selects only the java_method(...) expression (no longer SELECT *)
//     and appends limit 10 to the LEFT JOIN of tmp1 and tmp2.
//   - Assertions on the executed plan, each made with find(plan):
//       * no BroadcastHashJoinExecTransformerBase node is present,
//       * a BroadcastHashJoinExec node is present,
//       * a ColumnarBroadcastExchangeExec node is present.
//   - Removed: the collectColumnarToRow count assertions and the trailing
//     spark.sql(...).show() call that followed runQueryAndCompare.
//
// ExpandFallbackPolicy.scala
//   - Removes the private helper hasColumnarBroadcastExchangeWithJoin, which searched the
//     plan for a BroadcastHashJoinExecTransformerBase whose left or right child was a
//     BroadcastQueryStageExec wrapping a ColumnarBroadcastExchangeExec.
//   - Removes the early return of FallbackInfo.DO_NOT_FALLBACK() in fallback(plan) that
//     was guarded by that helper.
//   - Removes the imports of BroadcastHashJoinExecTransformerBase and
//     BroadcastQueryStageExec (presumably no longer referenced after the helper is gone;
//     verify against the full file).
//
// NOTE(review): this copy of the patch appears to have lost its original line breaks;
//   restore them before attempting to apply it.
// NOTE(review): the new query applies limit 10 without ORDER BY; confirm that the result
//   comparison done by runQueryAndCompare cannot be affected by which rows are returned.
// NOTE(review): the deleted guard was annotated as protecting a row-based BHJ fed by an
//   already-columnar broadcast exchange. Nothing in this patch shows what makes that
//   combination safe now; confirm the enabling change exists elsewhere.
diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala index 2b40ac54b2c6..0f1256923275 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/FallbackSuite.scala @@ -20,10 +20,10 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.extension.GlutenPlan import org.apache.spark.SparkConf -import org.apache.spark.sql.execution.{ColumnarShuffleExchangeExec, SparkPlan} +import org.apache.spark.sql.execution.{ColumnarBroadcastExchangeExec, ColumnarShuffleExchangeExec, SparkPlan} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, AQEShuffleReadExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec -import org.apache.spark.sql.execution.joins.SortMergeJoinExec +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper { protected val rootPath: String = getClass.getResource("/").getPath @@ -106,35 +106,36 @@ class FallbackSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPl } } - test("fallback with bhj") { - withSQLConf(GlutenConfig.COLUMNAR_WHOLESTAGE_FALLBACK_THRESHOLD.key -> "2") { + test("offload BroadcastExchange and fall back BHJ") { + withSQLConf( + "spark.gluten.sql.columnar.broadcastJoin" -> "false" + ) { runQueryAndCompare( """ - |SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1 - |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 + |SELECT java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1 + |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 limit 10 |""".stripMargin ) { df => val plan = df.queryExecution.executedPlan - val bhj = find(plan) { + val columnarBhj = find(plan) { case _: BroadcastHashJoinExecTransformerBase => true case _ => false } - 
assert(bhj.isDefined) - val columnarToRow = collectColumnarToRow(bhj.get) - assert(columnarToRow == 0) + assert(!columnarBhj.isDefined) - val wholeQueryColumnarToRow = collectColumnarToRow(plan) - assert(wholeQueryColumnarToRow == 1) - } + val vanillaBhj = find(plan) { + case _: BroadcastHashJoinExec => true + case _ => false + } + assert(vanillaBhj.isDefined) - // before the fix, it would fail - spark - .sql(""" - |SELECT *, java_method('java.lang.Integer', 'sum', tmp1.c1, tmp2.c1) FROM tmp1 - |LEFT JOIN tmp2 on tmp1.c1 = tmp2.c1 - |""".stripMargin) - .show() + val columnarBroadcastExchange = find(plan) { + case _: ColumnarBroadcastExchangeExec => true + case _ => false + } + assert(columnarBroadcastExchange.isDefined) + } } } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala index 491b54443d67..29e1caae74ff 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/ExpandFallbackPolicy.scala @@ -17,7 +17,6 @@ package org.apache.gluten.extension.columnar import org.apache.gluten.GlutenConfig -import org.apache.gluten.execution.BroadcastHashJoinExecTransformerBase import org.apache.gluten.extension.GlutenPlan import org.apache.gluten.extension.columnar.transition.{ColumnarToRowLike, RowToColumnarLike, Transitions} import org.apache.gluten.utils.PlanUtil @@ -27,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, BroadcastQueryStageExec, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} import 
org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.command.ExecutedCommandExec import org.apache.spark.sql.execution.exchange.Exchange @@ -179,21 +178,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP stageFallbackTransitionCost } - private def hasColumnarBroadcastExchangeWithJoin(plan: SparkPlan): Boolean = { - def isColumnarBroadcastExchange(p: SparkPlan): Boolean = p match { - case BroadcastQueryStageExec(_, _: ColumnarBroadcastExchangeExec, _) => true - case _ => false - } - - plan.find { - case j: BroadcastHashJoinExecTransformerBase - if isColumnarBroadcastExchange(j.left) || - isColumnarBroadcastExchange(j.right) => - true - case _ => false - }.isDefined - } - private def fallback(plan: SparkPlan): FallbackInfo = { val fallbackThreshold = if (isAdaptiveContext) { GlutenConfig.getConf.wholeStageFallbackThreshold @@ -210,11 +194,6 @@ case class ExpandFallbackPolicy(isAdaptiveContext: Boolean, originalPlan: SparkP return FallbackInfo.DO_NOT_FALLBACK() } - // not safe to fallback row-based BHJ as the broadcast exchange is already columnar - if (hasColumnarBroadcastExchangeWithJoin(plan)) { - return FallbackInfo.DO_NOT_FALLBACK() - } - val transitionCost = countTransitionCost(plan) val fallbackTransitionCost = if (isAdaptiveContext) { countStageFallbackTransitionCost(plan)