[CORE] Optimize JoinSelectionOverrides
zml1206 committed Jun 11, 2024
1 parent d3ccd4a commit edbc020
Showing 4 changed files with 137 additions and 227 deletions.
@@ -27,12 +27,13 @@ import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode
 import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{ExpandOutputPartitioningShim, ExplainUtils, SparkPlan}
-import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode, HashJoin}
+import org.apache.spark.sql.execution.{ExpandOutputPartitioningShim, ExplainUtils, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, HashedRelationBroadcastMode, HashJoin, ShuffledJoin}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -420,3 +421,11 @@ abstract class BroadcastHashJoinExecTransformerBase(
     (1, if (isNullAwareAntiJoin) 1 else 0, buildHashTableId)
   }
 }
+
+case class ShuffledHashJoinExecTemp(child: ShuffledJoin, buildSide: BuildSide)
+  extends UnaryExecNode {
+  override protected def doExecute(): RDD[InternalRow] = throw new UnsupportedOperationException()
+  override def output: Seq[Attribute] = child.output
+  override protected def withNewChildInternal(newChild: SparkPlan): ShuffledHashJoinExecTemp =
+    copy(child = newChild.asInstanceOf[ShuffledJoin])
+}
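
ShuffledHashJoinExecTemp is a transient placeholder: it records the build side chosen during join selection and is rewritten into a backend transformer during columnar offload (see OffloadJoin below), which is why doExecute() deliberately throws. A minimal self-contained sketch of the same two-phase handoff pattern, using illustrative names rather than the Gluten or Spark API:

    // Phase 1 wraps the decision in a marker node; executing the marker is a bug.
    sealed trait Plan
    case class HashJoin(keys: Seq[String], buildLeft: Boolean) extends Plan
    case class TempJoin(keys: Seq[String], buildLeft: Boolean) extends Plan {
      def execute(): Nothing =
        throw new UnsupportedOperationException("placeholder; rewrite before execution")
    }

    // Phase 2 pattern-matches the marker and emits the real operator.
    def offload(p: Plan): Plan = p match {
      case TempJoin(keys, buildLeft) => HashJoin(keys, buildLeft)
      case other => other
    }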
@@ -18,19 +18,18 @@ package org.apache.gluten.extension
 
 import org.apache.gluten.{GlutenConfig, GlutenSparkExtensionsInjector}
 import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.execution.ShuffledHashJoinExecTemp
 import org.apache.gluten.extension.columnar.TRANSFORM_UNSUPPORTED
 import org.apache.gluten.extension.columnar.TransformHints.TAG
 import org.apache.gluten.utils.LogicalPlanSelector
 
 import org.apache.spark.sql.{SparkSession, SparkSessionExtensions, Strategy}
 import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{joins, JoinSelectionShim, SparkPlan}
-import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, LogicalQueryStage}
-import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
+import org.apache.spark.sql.execution.{JoinSelectionShim, SparkPlan}
+import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, ShuffledJoin, SortMergeJoinExec}
 
 object StrategyOverrides extends GlutenSparkExtensionsInjector {
   override def inject(extensions: SparkSessionExtensions): Unit = {
@@ -43,112 +42,6 @@ case class JoinSelectionOverrides(session: SparkSession)
   with JoinSelectionHelper
   with SQLConfHelper {
 
-  private def isBroadcastStage(plan: LogicalPlan): Boolean = plan match {
-    case LogicalQueryStage(_, _: BroadcastQueryStageExec) => true
-    case _ => false
-  }
-
-  def extractEqualJoinKeyCondition(
-      joinType: JoinType,
-      leftKeys: Seq[Expression],
-      rightKeys: Seq[Expression],
-      condition: Option[Expression],
-      left: LogicalPlan,
-      right: LogicalPlan,
-      hint: JoinHint,
-      forceShuffledHashJoin: Boolean): Seq[SparkPlan] = {
-    if (isBroadcastStage(left) || isBroadcastStage(right)) {
-      val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight
-      Seq(
-        BroadcastHashJoinExec(
-          leftKeys,
-          rightKeys,
-          joinType,
-          buildSide,
-          condition,
-          planLater(left),
-          planLater(right)))
-    } else {
-      // Generate BHJ here, avoid to do match in `JoinSelection` again.
-      val isHintEmpty = hint.leftHint.isEmpty && hint.rightHint.isEmpty
-      val buildSide = getBroadcastBuildSide(left, right, joinType, hint, !isHintEmpty, conf)
-      if (buildSide.isDefined) {
-        return Seq(
-          joins.BroadcastHashJoinExec(
-            leftKeys,
-            rightKeys,
-            joinType,
-            buildSide.get,
-            condition,
-            planLater(left),
-            planLater(right)))
-      }
-
-      if (
-        forceShuffledHashJoin &&
-        !BackendsApiManager.getSparkPlanExecApiInstance.joinFallback(
-          joinType,
-          left.outputSet,
-          right.outputSet,
-          condition) &&
-        !left.getTagValue(TAG).isDefined &&
-        !right.getTagValue(TAG).isDefined
-      ) {
-        // Force use of ShuffledHashJoin in preference to SortMergeJoin. With no respect to
-        // conf setting "spark.sql.join.preferSortMergeJoin".
-        val (leftBuildable, rightBuildable) =
-          if (BackendsApiManager.getSettings.utilizeShuffledHashJoinHint()) {
-            // Currently, ClickHouse backend can not support AQE, so it needs to use join hint
-            // to decide the build side, after supporting AQE, will remove this.
-            val leftHintEnabled = hintToShuffleHashJoinLeft(hint)
-            val rightHintEnabled = hintToShuffleHashJoinRight(hint)
-            val leftHintMergeEnabled = hint.leftHint.exists(_.strategy.contains(SHUFFLE_MERGE))
-            val rightHintMergeEnabled = hint.rightHint.exists(_.strategy.contains(SHUFFLE_MERGE))
-            if (leftHintEnabled || rightHintEnabled) {
-              (leftHintEnabled, rightHintEnabled)
-            } else if (leftHintMergeEnabled || rightHintMergeEnabled) {
-              // hack: when set SHUFFLE_MERGE hint, it means that
-              // it don't use this side as the build side
-              (!leftHintMergeEnabled, !rightHintMergeEnabled)
-            } else {
-              (
-                BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(joinType),
-                BackendsApiManager.getSettings.supportHashBuildJoinTypeOnRight(joinType))
-            }
-          } else {
-            (canBuildShuffledHashJoinLeft(joinType), canBuildShuffledHashJoinRight(joinType))
-          }
-
-        if (!leftBuildable && !rightBuildable) {
-          return Nil
-        }
-        val buildSide = if (!leftBuildable) {
-          BuildRight
-        } else if (!rightBuildable) {
-          BuildLeft
-        } else {
-          getSmallerSide(left, right)
-        }
-
-        return Option(buildSide)
-          .map {
-            buildSide =>
-              Seq(
-                joins.ShuffledHashJoinExec(
-                  leftKeys,
-                  rightKeys,
-                  joinType,
-                  buildSide,
-                  condition,
-                  planLater(left),
-                  planLater(right)))
-          }
-          .getOrElse(Nil)
-      }
-      Nil
-    }
-  }
-
   def existsMultiJoins(plan: LogicalPlan, count: Int = 0): Boolean = {
     plan match {
       case plan: Join =>
@@ -157,7 +50,7 @@ case class JoinSelectionOverrides(session: SparkSession)
       case plan: Project =>
         if ((count + 1) >= GlutenConfig.getConf.logicalJoinOptimizationThrottle) return true
         plan.children.exists(existsMultiJoins(_, count + 1))
-      case other => false
+      case _ => false
     }
   }

@@ -166,6 +59,11 @@ case class JoinSelectionOverrides(session: SparkSession)
     plan
   }
 
+  def tagNotTransformable(plan: ShuffledJoin, reason: String): ShuffledJoin = {
+    plan.setTagValue(TAG, TRANSFORM_UNSUPPORTED(Some(reason)))
+    plan
+  }
+
   def tagNotTransformableRecursive(plan: LogicalPlan, reason: String): LogicalPlan = {
     tagNotTransformable(
       plan.withNewChildren(plan.children.map(tagNotTransformableRecursive(_, reason))),
@@ -179,6 +77,26 @@ case class JoinSelectionOverrides(session: SparkSession)
     }.size > 0
   }
 
+  def genShuffledHashJoinExecTemp(
+      joinType: JoinType,
+      left: LogicalPlan,
+      right: LogicalPlan,
+      join: ShuffledJoin): ShuffledHashJoinExecTemp = {
+    val leftBuildable = BackendsApiManager.getSettings
+      .supportHashBuildJoinTypeOnLeft(joinType)
+    val rightBuildable = BackendsApiManager.getSettings
+      .supportHashBuildJoinTypeOnRight(joinType)
+    val buildSide = if (!leftBuildable) {
+      BuildRight
+    } else if (!rightBuildable) {
+      BuildLeft
+    } else {
+      getSmallerSide(left, right)
+    }
+    val child = tagNotTransformable(join, "child of ShuffledHashJoinExecTemp")
+    ShuffledHashJoinExecTemp(child, buildSide)
+  }
+
   override def apply(plan: LogicalPlan): Seq[SparkPlan] =
     LogicalPlanSelector.maybeNil(session, plan) {
       // Ignore forceShuffledHashJoin if exist multi continuous joins
@@ -189,24 +107,30 @@ case class JoinSelectionOverrides(session: SparkSession)
         tagNotTransformableRecursive(plan, "exist multi continuous joins")
       }
       plan match {
-        // If the build side of BHJ is already decided by AQE, we need to keep the build side.
-        case JoinSelectionShim.ExtractEquiJoinKeysShim(
+        case j @ JoinSelectionShim.ExtractEquiJoinKeysShim(
               joinType,
-              leftKeys,
-              rightKeys,
+              _,
+              _,
               condition,
               left,
               right,
-              hint) =>
-          extractEqualJoinKeyCondition(
-            joinType,
-            leftKeys,
-            rightKeys,
-            condition,
-            left,
-            right,
-            hint,
-            GlutenConfig.getConf.forceShuffledHashJoin)
+              _) =>
+          val originalJoinExec = session.sessionState.planner.JoinSelection.apply(j)
+          originalJoinExec(0) match {
+            case shj: ShuffledHashJoinExec =>
+              Seq(genShuffledHashJoinExecTemp(joinType, left, right, shj))
+            case smj: SortMergeJoinExec
+                if GlutenConfig.getConf.forceShuffledHashJoin &&
+                  !BackendsApiManager.getSparkPlanExecApiInstance.joinFallback(
+                    joinType,
+                    left.outputSet,
+                    right.outputSet,
+                    condition) &&
+                  !left.getTagValue(TAG).isDefined &&
+                  !right.getTagValue(TAG).isDefined =>
+              Seq(genShuffledHashJoinExecTemp(joinType, left, right, smj))
+            case _ => originalJoinExec
+          }
         case _ => Nil
       }
     }
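
The build-side choice in genShuffledHashJoinExecTemp above reduces to a small pure rule: build on the only side the backend can host a hash table on, otherwise on the smaller side. A self-contained sketch of that rule (byte sizes stand in for the plan statistics getSmallerSide consults; illustrative names, not the Gluten API):

    sealed trait BuildSide
    case object BuildLeft extends BuildSide
    case object BuildRight extends BuildSide

    def chooseBuildSide(
        leftBuildable: Boolean,
        rightBuildable: Boolean,
        leftBytes: Long,
        rightBytes: Long): BuildSide =
      if (!leftBuildable) BuildRight // also the fallback when neither side is buildable
      else if (!rightBuildable) BuildLeft
      else if (leftBytes <= rightBytes) BuildLeft // both buildable: hash the smaller side
      else BuildRight

    // e.g. only the right side supports the hash build for this join type:
    assert(chooseBuildSide(leftBuildable = false, rightBuildable = true, 10L, 1000L) == BuildRight)

Note that, unlike the deleted extractEqualJoinKeyCondition, this path no longer returns Nil when neither side is buildable; the first branch falls back to BuildRight.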
@@ -119,69 +119,75 @@ case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil {
 // Join transformation.
 case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
 
-  override def offload(plan: SparkPlan): SparkPlan = {
-    if (TransformHints.isNotTransformable(plan)) {
+  private def dropPartialSort(plan: SparkPlan): SparkPlan = plan match {
+    case sort: SortExecTransformer if !sort.global =>
+      sort.child
+    case sort: SortExec if !sort.global =>
+      sort.child
+    case _ => plan
+  }
+
+  override def offload(plan: SparkPlan): SparkPlan = plan match {
+    case ShuffledHashJoinExecTemp(join, _) if TransformHints.isNotTransformable(plan) =>
+      logDebug(s"Columnar Processing for ${join.getClass} is under row guard.")
+      join
+    case p if TransformHints.isNotTransformable(p) =>
       logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
-      return plan
-    }
-    plan match {
-      case plan: ShuffledHashJoinExec =>
-        val left = plan.left
-        val right = plan.right
-        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genShuffledHashJoinExecTransformer(
-            plan.leftKeys,
-            plan.rightKeys,
-            plan.joinType,
-            TransformHints.getShuffleHashJoinBuildSide(plan),
-            plan.condition,
-            left,
-            right,
-            plan.isSkewJoin)
-      case plan: SortMergeJoinExec =>
-        val left = plan.left
-        val right = plan.right
-        logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genSortMergeJoinExecTransformer(
-            plan.leftKeys,
-            plan.rightKeys,
-            plan.joinType,
-            plan.condition,
-            left,
-            right,
-            plan.isSkewJoin)
-      case plan: BroadcastHashJoinExec =>
-        val left = plan.left
-        val right = plan.right
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genBroadcastHashJoinExecTransformer(
-            plan.leftKeys,
-            plan.rightKeys,
-            plan.joinType,
-            plan.buildSide,
-            plan.condition,
-            left,
-            right,
-            isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
-      case plan: CartesianProductExec =>
-        val left = plan.left
-        val right = plan.right
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genCartesianProductExecTransformer(left, right, plan.condition)
-      case plan: BroadcastNestedLoopJoinExec =>
-        val left = plan.left
-        val right = plan.right
-        BackendsApiManager.getSparkPlanExecApiInstance
-          .genBroadcastNestedLoopJoinExecTransformer(
-            left,
-            right,
-            plan.buildSide,
-            plan.joinType,
-            plan.condition)
-      case other => other
-    }
+      p
+    case ShuffledHashJoinExecTemp(join, buildSide) =>
+      logDebug(s"Columnar Processing for ${join.getClass} is currently supported.")
+      BackendsApiManager.getSparkPlanExecApiInstance
+        .genShuffledHashJoinExecTransformer(
+          join.leftKeys,
+          join.rightKeys,
+          join.joinType,
+          buildSide,
+          join.condition,
+          dropPartialSort(join.left),
+          dropPartialSort(join.right),
+          join.isSkewJoin)
+    case plan: SortMergeJoinExec =>
+      val left = plan.left
+      val right = plan.right
+      logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
+      BackendsApiManager.getSparkPlanExecApiInstance
+        .genSortMergeJoinExecTransformer(
+          plan.leftKeys,
+          plan.rightKeys,
+          plan.joinType,
+          plan.condition,
+          left,
+          right,
+          plan.isSkewJoin)
+    case plan: BroadcastHashJoinExec =>
+      val left = plan.left
+      val right = plan.right
+      BackendsApiManager.getSparkPlanExecApiInstance
+        .genBroadcastHashJoinExecTransformer(
+          plan.leftKeys,
+          plan.rightKeys,
+          plan.joinType,
+          plan.buildSide,
+          plan.condition,
+          left,
+          right,
+          isNullAwareAntiJoin = plan.isNullAwareAntiJoin)
+    case plan: CartesianProductExec =>
+      val left = plan.left
+      val right = plan.right
+      BackendsApiManager.getSparkPlanExecApiInstance
+        .genCartesianProductExecTransformer(left, right, plan.condition)
+    case plan: BroadcastNestedLoopJoinExec =>
+      val left = plan.left
+      val right = plan.right
+      BackendsApiManager.getSparkPlanExecApiInstance
+        .genBroadcastNestedLoopJoinExecTransformer(
+          left,
+          right,
+          plan.buildSide,
+          plan.joinType,
+          plan.condition)
+    case other => other
   }
 
 }
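
dropPartialSort exists because a SortMergeJoinExec carries partition-local (global = false) sorts under each child to satisfy its ordering requirement; once the join runs as a shuffled hash join those sorts are pure overhead, so both children are unwrapped before the transformer is built. A self-contained analogue with toy plan nodes (illustrative, not the Spark API):

    sealed trait Plan
    case class Sort(global: Boolean, child: Plan) extends Plan
    case class Exchange(child: Plan) extends Plan
    case object Scan extends Plan

    // Same shape as dropPartialSort above: strip one partition-local sort, keep anything else.
    def dropPartialSort(p: Plan): Plan = p match {
      case Sort(false, child) => child
      case other => other
    }

    // The per-partition sort inserted for a sort-merge join is removed...
    assert(dropPartialSort(Sort(global = false, child = Exchange(Scan))) == Exchange(Scan))
    // ...while a global (result-ordering) sort is preserved.
    assert(dropPartialSort(Sort(global = true, child = Exchange(Scan)))
      == Sort(global = true, child = Exchange(Scan)))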
(fourth changed file not shown)
