Skip to content

Commit

Permalink
[VL] Make conf option s.g.s.c.shuffledHashJoin.optimizeBuildSide wo…
Browse files Browse the repository at this point in the history
…rk correctly with option `s.g.s.c.forceShuffledHashJoin` (apache#7186)
  • Loading branch information
zhztheplayer authored and shamirchen committed Oct 14, 2024
1 parent f12f7bc commit 3df7919
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
plan.leftKeys,
plan.rightKeys,
plan.joinType,
OffloadJoin.getBuildSide(plan),
OffloadJoin.getShjBuildSide(plan),
plan.condition,
plan.left,
plan.right,
Expand Down Expand Up @@ -443,13 +443,13 @@ case class AddFallbackTagRule() extends Rule[SparkPlan] {
offset)
transformer.doValidate().tagOnFallback(plan)
case plan: SampleExec =>
val transformer = BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
plan.lowerBound,
plan.upperBound,
plan.withReplacement,
plan.seed,
plan.child
)
val transformer =
BackendsApiManager.getSparkPlanExecApiInstance.genSampleExecTransformer(
plan.lowerBound,
plan.upperBound,
plan.withReplacement,
plan.seed,
plan.child)
transformer.doValidate().tagOnFallback(plan)
case _ =>
// Currently we assume a plan to be offload-able by default.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ case class OffloadExchange() extends OffloadSingleNode with LogLevelUtil {

// Join transformation.
case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {

override def offload(plan: SparkPlan): SparkPlan = {
if (FallbackTags.nonEmpty(plan)) {
logDebug(s"Columnar Processing for ${plan.getClass} is under row guard.")
Expand All @@ -134,7 +133,7 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
plan.leftKeys,
plan.rightKeys,
plan.joinType,
OffloadJoin.getBuildSide(plan),
OffloadJoin.getShjBuildSide(plan),
plan.condition,
left,
right,
Expand Down Expand Up @@ -186,37 +185,53 @@ case class OffloadJoin() extends OffloadSingleNode with LogLevelUtil {
}

object OffloadJoin {

def getBuildSide(shj: ShuffledHashJoinExec): BuildSide = {
def getShjBuildSide(shj: ShuffledHashJoinExec): BuildSide = {
val leftBuildable =
BackendsApiManager.getSettings.supportHashBuildJoinTypeOnLeft(shj.joinType)
val rightBuildable =
BackendsApiManager.getSettings.supportHashBuildJoinTypeOnRight(shj.joinType)

assert(leftBuildable || rightBuildable)

if (!leftBuildable) {
return BuildRight
}
if (!rightBuildable) {
return BuildLeft
}

// Both left and right are buildable. Find out the better one.
if (!GlutenConfig.getConf.shuffledHashJoinOptimizeBuildSide) {
// User disabled build side re-optimization. Return original build side from vanilla Spark.
return shj.buildSide
}
shj.logicalLink match {
case Some(join: Join) =>
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
val leftRowCount = join.left.stats.rowCount
val rightRowCount = join.right.stats.rowCount
if (rightSize == leftSize && rightRowCount.isDefined && leftRowCount.isDefined) {
if (rightRowCount.get <= leftRowCount.get) BuildRight
else BuildLeft
} else if (rightSize <= leftSize) BuildRight
else BuildLeft
// Only the ShuffledHashJoinExec generated directly in some spark tests is not link
// logical plan, such as OuterJoinSuite.
case _ => shj.buildSide
shj.logicalLink
.flatMap {
case join: Join => Some(getOptimalBuildSide(join))
case _ => None
}
.getOrElse {
// Some shj operators generated in certain Spark tests such as OuterJoinSuite,
// could possibly have no logical link set.
shj.buildSide
}
}

def getOptimalBuildSide(join: Join): BuildSide = {
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
val leftRowCount = join.left.stats.rowCount
val rightRowCount = join.right.stats.rowCount
if (leftSize == rightSize && rightRowCount.isDefined && leftRowCount.isDefined) {
if (rightRowCount.get <= leftRowCount.get) {
return BuildRight
}
return BuildLeft
}
if (rightSize <= leftSize) {
return BuildRight
}
BuildLeft
}
}

Expand Down Expand Up @@ -332,8 +347,7 @@ object OffloadOthers {
plan.partitionColumns,
plan.bucketSpec,
plan.options,
plan.staticPartitions
)
plan.staticPartitions)
case plan: SortExec =>
val child = plan.child
logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,43 @@
package org.apache.gluten.extension.columnar.rewrite

import org.apache.gluten.GlutenConfig
import org.apache.gluten.extension.columnar.OffloadJoin

import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
import org.apache.spark.sql.catalyst.plans.JoinType
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}

/**
* If force ShuffledHashJoin, convert [[SortMergeJoinExec]] to [[ShuffledHashJoinExec]]. There is no
* need to select a smaller table as buildSide here, it will be reselected when offloading.
*/
/** If force ShuffledHashJoin, convert [[SortMergeJoinExec]] to [[ShuffledHashJoinExec]]. */
object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {

private def getBuildSide(joinType: JoinType): Option[BuildSide] = {
val leftBuildable = canBuildShuffledHashJoinLeft(joinType)
val rightBuildable = canBuildShuffledHashJoinRight(joinType)
if (rightBuildable) {
Some(BuildRight)
} else if (leftBuildable) {
Some(BuildLeft)
} else {
None
private def getSmjBuildSide(join: SortMergeJoinExec): Option[BuildSide] = {
val leftBuildable = canBuildShuffledHashJoinLeft(join.joinType)
val rightBuildable = canBuildShuffledHashJoinRight(join.joinType)
if (!leftBuildable && !rightBuildable) {
return None
}
if (!leftBuildable) {
return Some(BuildRight)
}
if (!rightBuildable) {
return Some(BuildLeft)
}
val side = join.logicalLink
.flatMap {
case join: Join => Some(OffloadJoin.getOptimalBuildSide(join))
case _ => None
}
.getOrElse {
// If smj has no logical link, or its logical link is not a join,
// then we always choose left as build side.
BuildLeft
}
Some(side)
}

override def rewrite(plan: SparkPlan): SparkPlan = plan match {
case smj: SortMergeJoinExec if GlutenConfig.getConf.forceShuffledHashJoin =>
getBuildSide(smj.joinType) match {
getSmjBuildSide(smj) match {
case Some(buildSide) =>
ShuffledHashJoinExec(
smj.leftKeys,
Expand All @@ -53,8 +63,7 @@ object RewriteJoin extends RewriteSingleNode with JoinSelectionHelper {
smj.condition,
smj.left,
smj.right,
smj.isSkewJoin
)
smj.isSkewJoin)
case _ => plan
}
case _ => plan
Expand Down

0 comments on commit 3df7919

Please sign in to comment.