From edbc020c663a2de501c389725d777304ca39816a Mon Sep 17 00:00:00 2001 From: zml1206 Date: Tue, 11 Jun 2024 22:08:42 +0800 Subject: [PATCH] [CORE] Optimize JoinSelectionOverrides --- .../execution/JoinExecTransformer.scala | 13 +- .../gluten/extension/StrategyOverrides.scala | 174 +++++------------- .../columnar/OffloadSingleNode.scala | 130 ++++++------- .../columnar/TransformHintRule.scala | 47 +---- 4 files changed, 137 insertions(+), 227 deletions(-) diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index cd22c578594c..9f4a914261e2 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -27,12 +27,13 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{ExpandOutputPartitioningShim, ExplainUtils, SparkPlan} -import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode, HashJoin} +import org.apache.spark.sql.execution.{ExpandOutputPartitioningShim, ExplainUtils, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode, HashJoin, ShuffledJoin} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.types._ import org.apache.spark.sql.vectorized.ColumnarBatch @@ -420,3 +421,11 @@ abstract class BroadcastHashJoinExecTransformerBase( (1, if (isNullAwareAntiJoin) 1 else 0, buildHashTableId) } } + +case class ShuffledHashJoinExecTemp(child: ShuffledJoin, buildSide: BuildSide) + extends UnaryExecNode { + override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException() + override def output: Seq[Attribute] = child.output + override protected def withNewChildInternal(newChild: SparkPlan): ShuffledHashJoinExecTemp = + copy(child = newChild.asInstanceOf[ShuffledJoin]) +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/StrategyOverrides.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/StrategyOverrides.scala index f2f786259393..3de929932fc1 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/StrategyOverrides.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/StrategyOverrides.scala @@ -18,19 +18,18 @@ package org.apache.gluten.extension import org.apache.gluten.{GlutenConfig, GlutenSparkExtensionsInjector} import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.execution.ShuffledHashJoinExecTemp import org.apache.gluten.extension.columnar.TRANSFORM_UNSUPPORTED import org.apache.gluten.extension.columnar.TransformHints.TAG import org.apache.gluten.utils.LogicalPlanSelector import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy} import org.apache.spark.sql.catalyst.SQLConfHelper -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{joins, JoinSelectionShim, SparkPlan} -import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, LogicalQueryStage} -import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.{JoinSelectionShim, SparkPlan} +import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec} object StrategyOverrides extends GlutenSparkExtensionsInjector { override def inject(extensions: SparkSessionExtensions): Unit = { @@ -43,112 +42,6 @@ case class JoinSelectionOverrides(session: SparkSession) with JoinSelectionHelper with SQLConfHelper { - private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match { - case LogicalQueryStage(_, _: BroadcastQueryStageExec) => true - case _ => false - } - - def extractEqualJoinKeyCondition( - joinType: JoinType, - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - condition: Option[Expression], - left: LogicalPlan, - right: LogicalPlan, - hint: JoinHint, - forceShuffledHashJoin: Boolean): Seq[SparkPlan] = { - if (isBroadcastStage(left) || isBroadcastStage(right)) { - val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight - Seq( - BroadcastHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) - } else { - // Generate BHJ here, avoid to do match in `JoinSelection` again. - val isHintEmpty = hint.leftHint.isEmpty && hint.rightHint.isEmpty - val buildSide = getBroadcastBuildSide(left, right, joinType, hint, !isHintEmpty, conf) - if (buildSide.isDefined) { - return Seq( - joins.BroadcastHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide.get, - condition, - planLater(left), - planLater(right))) - } - - if ( - forceShuffledHashJoin && - !BackendsApiManager.getSparkPlanExecApiInstance.joinFallback( - joinType, - left.outputSet, - right.outputSet, - condition) && - !left.getTagValue(TAG).isDefined && - !right.getTagValue(TAG).isDefined - ) { - // Force use of ShuffledHashJoin in preference to SortMergeJoin. With no respect to - // conf setting "spark.sql.join.preferSortMergeJoin". - val (leftBuildable, rightBuildable) = - if (BackendsApiManager.getSettings.utilizeShuffledHashJoinHint()) { - // Currently, ClickHouse backend can not support AQE, so it needs to use join hint - // to decide the build side, after supporting AQE, will remove this. - val leftHintEnabled = hintToShuffleHashJoinLeft(hint) - val rightHintEnabled = hintToShuffleHashJoinRight(hint) - val leftHintMergeEnabled = hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - val rightHintMergeEnabled = hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE)) - if (leftHintEnabled || rightHintEnabled) { - (leftHintEnabled, rightHintEnabled) - } else if (leftHintMergeEnabled || rightHintMergeEnabled) { - // hack: when set SHUFFLE_MERGE hint, it means that - // it don't use this side as the build side - (!leftHintMergeEnabled, !rightHintMergeEnabled) - } else { - ( - BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(joinType), - BackendsApiManager.getSettings.supportHashBuildJoinTypeOnRight(joinType)) - } - } else { - (canBuildShuffledHashJoinLeft(joinType), canBuildShuffledHashJoinRight(joinType)) - } - - if (!leftBuildable && !rightBuildable) { - return Nil - } - val buildSide = if (!leftBuildable) { - BuildRight - } else if (!rightBuildable) { - BuildLeft - } else { - getSmallerSide(left, right) - } - - return Option(buildSide) - .map { - buildSide => - Seq( - joins.ShuffledHashJoinExec( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - planLater(left), - planLater(right))) - } - .getOrElse(Nil) - } - Nil - } - } - def existsMultiJoins(plan: LogicalPlan, count: Int = 0): Boolean = { plan match { case plan: Join => @@ -157,7 +50,7 @@ case class JoinSelectionOverrides(session: SparkSession) case plan: Project => if ((count + 1) >= GlutenConfig.getConf.logicalJoinOptimizationThrottle) return true plan.children.exists(existsMultiJoins(_, count + 1)) - case other => false + case _ => false } } @@ -166,6 +59,11 @@ case class JoinSelectionOverrides(session: SparkSession) plan } + def tagNotTransformable(plan: ShuffledJoin, reason: String): ShuffledJoin = { + plan.setTagValue(TAG, TRANSFORM_UNSUPPORTED(Some(reason))) + plan + } + def tagNotTransformableRecursive(plan: LogicalPlan, reason: String): LogicalPlan = { tagNotTransformable( plan.withNewChildren(plan.children.map(tagNotTransformableRecursive(_, reason))), @@ -179,6 +77,26 @@ case class JoinSelectionOverrides(session: SparkSession) }.size > 0 } + def genShuffledHashJoinExecTemp( + joinType: JoinType, + left: LogicalPlan, + right: LogicalPlan, + join: ShuffledJoin): ShuffledHashJoinExecTemp = { + val leftBuildable = BackendsApiManager.getSettings + .supportHashBuildJoinTypeOnLeft(joinType) + val rightBuildable = BackendsApiManager.getSettings + .supportHashBuildJoinTypeOnRight(joinType) + val buildSide = if (!leftBuildable) { + BuildRight + } else if (!rightBuildable) { + BuildLeft + } else { + getSmallerSide(left, right) + } + val child = tagNotTransformable(join, "child of ShuffledHashJoinExecTemp") + ShuffledHashJoinExecTemp(child, buildSide) + } + override def apply(plan: LogicalPlan): Seq[SparkPlan] = LogicalPlanSelector.maybeNil(session, plan) { // Ignore forceShuffledHashJoin if exist multi continuous joins @@ -189,24 +107,30 @@ case class JoinSelectionOverrides(session: SparkSession) tagNotTransformableRecursive(plan, "exist multi continuous joins") } plan match { - // If the build side of BHJ is already decided by AQE, we need to keep the build side. - case JoinSelectionShim.ExtractEquiJoinKeysShim( + case j @ JoinSelectionShim.ExtractEquiJoinKeysShim( joinType, - leftKeys, - rightKeys, + _, + _, condition, left, right, - hint) => - extractEqualJoinKeyCondition( - joinType, - leftKeys, - rightKeys, - condition, - left, - right, - hint, - GlutenConfig.getConf.forceShuffledHashJoin) + _) => + val originalJoinExec = session.sessionState.planner.JoinSelection.apply(j) + originalJoinExec(0) match { + case shj: ShuffledHashJoinExec => + Seq(genShuffledHashJoinExecTemp(joinType, left, right, shj)) + case smj: SortMergeJoinExec + if GlutenConfig.getConf.forceShuffledHashJoin && + !BackendsApiManager.getSparkPlanExecApiInstance.joinFallback( + joinType, + left.outputSet, + right.outputSet, + condition) && + !left.getTagValue(TAG).isDefined && + !right.getTagValue(TAG).isDefined => + Seq(genShuffledHashJoinExecTemp(joinType, left, right, smj)) + case _ => originalJoinExec + } case _ => Nil } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index 6e4d37f633eb..f0e042b22015 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -119,69 +119,75 @@ case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil { // Join transformation. case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil { - override def offload(plan: SparkPlan): SparkPlan = { - if (TransformHints.isNotTransformable(plan)) { + private def dropPartialSort(plan: SparkPlan): SparkPlan = plan match { + case sort: SortExecTransformer if !sort.global => + sort.child + case sort: SortExec if !sort.global => + sort.child + case _ => plan + } + + override def offload(plan: SparkPlan): SparkPlan = plan match { + case ShuffledHashJoinExecTemp(join, _) if TransformHints.isNotTransformable(plan) => + logDebug(s"Columnar Processing for ${join.getClass} is under row guard.") + join + case p if TransformHints.isNotTransformable(p) => logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.") - return plan - } - plan match { - case plan: ShuffledHashJoinExec => - val left = plan.left - val right = plan.right - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - BackendsApiManager.getSparkPlanExecApiInstance - .genShuffledHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - TransformHints.getShuffleHashJoinBuildSide(plan), - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: SortMergeJoinExec => - val left = plan.left - val right = plan.right - logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - BackendsApiManager.getSparkPlanExecApiInstance - .genSortMergeJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.condition, - left, - right, - plan.isSkewJoin) - case plan: BroadcastHashJoinExec => - val left = plan.left - val right = plan.right - BackendsApiManager.getSparkPlanExecApiInstance - .genBroadcastHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - isNullAwareAntiJoin = plan.isNullAwareAntiJoin) - case plan: CartesianProductExec => - val left = plan.left - val right = plan.right - BackendsApiManager.getSparkPlanExecApiInstance - .genCartesianProductExecTransformer(left, right, plan.condition) - case plan: BroadcastNestedLoopJoinExec => - val left = plan.left - val right = plan.right - BackendsApiManager.getSparkPlanExecApiInstance - .genBroadcastNestedLoopJoinExecTransformer( - left, - right, - plan.buildSide, - plan.joinType, - plan.condition) - case other => other - } + p + case ShuffledHashJoinExecTemp(join, buildSide) => + logDebug(s"Columnar Processing for ${join.getClass} is currently supported.") + BackendsApiManager.getSparkPlanExecApiInstance + .genShuffledHashJoinExecTransformer( + join.leftKeys, + join.rightKeys, + join.joinType, + buildSide, + join.condition, + dropPartialSort(join.left), + dropPartialSort(join.right), + join.isSkewJoin) + case plan: SortMergeJoinExec => + val left = plan.left + val right = plan.right + logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") + BackendsApiManager.getSparkPlanExecApiInstance + .genSortMergeJoinExecTransformer( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.condition, + left, + right, + plan.isSkewJoin) + case plan: BroadcastHashJoinExec => + val left = plan.left + val right = plan.right + BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastHashJoinExecTransformer( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + isNullAwareAntiJoin = plan.isNullAwareAntiJoin) + case plan: CartesianProductExec => + val left = plan.left + val right = plan.right + BackendsApiManager.getSparkPlanExecApiInstance + .genCartesianProductExecTransformer(left, right, plan.condition) + case plan: BroadcastNestedLoopJoinExec => + val left = plan.left + val right = plan.right + BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastNestedLoopJoinExecTransformer( + left, + right, + plan.buildSide, + plan.joinType, + plan.condition) + case other => other } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala index ca35c74f6892..efb37d8d2ca6 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/TransformHintRule.scala @@ -29,8 +29,6 @@ import org.apache.spark.api.python.EvalPythonExecTransformer import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.AttributeReference -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.logical.Join import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.execution._ @@ -160,33 +158,6 @@ object TransformHints { tag(plan, newTag) } } - - def getShuffleHashJoinBuildSide(shj: ShuffledHashJoinExec): BuildSide = { - if (BackendsApiManager.getSettings.utilizeShuffledHashJoinHint()) { - shj.buildSide - } else { - val leftBuildable = BackendsApiManager.getSettings - .supportHashBuildJoinTypeOnLeft(shj.joinType) - val rightBuildable = BackendsApiManager.getSettings - .supportHashBuildJoinTypeOnRight(shj.joinType) - - if (!leftBuildable) { - BuildRight - } else if (!rightBuildable) { - BuildLeft - } else { - shj.logicalLink match { - case Some(join: Join) => - val leftSize = join.left.stats.sizeInBytes - val rightSize = join.right.stats.sizeInBytes - if (rightSize <= leftSize) BuildRight else BuildLeft - // Only the ShuffledHashJoinExec generated directly in some spark tests is not link - // logical plan, such as OuterJoinSuite. - case _ => shj.buildSide - } - } - } - } } case class FallbackOnANSIMode(session: SparkSession) extends Rule[SparkPlan] { @@ -415,17 +386,17 @@ case class AddTransformHintRule() extends Rule[SparkPlan] { case plan: ShuffleExchangeExec => val transformer = ColumnarShuffleExchangeExec(plan, plan.child, plan.child.output) transformer.doValidate().tagOnFallback(plan) - case plan: ShuffledHashJoinExec => + case plan @ ShuffledHashJoinExecTemp(join, buildSide) => val transformer = BackendsApiManager.getSparkPlanExecApiInstance .genShuffledHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - TransformHints.getShuffleHashJoinBuildSide(plan), - plan.condition, - plan.left, - plan.right, - plan.isSkewJoin) + join.leftKeys, + join.rightKeys, + join.joinType, + buildSide, + join.condition, + join.left, + join.right, + join.isSkewJoin) transformer.doValidate().tagOnFallback(plan) case plan: BroadcastExchangeExec => val transformer = ColumnarBroadcastExchangeExec(plan.mode, plan.child)