Skip to content

Commit

Permalink
refactor reordering join tables
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 15, 2024
1 parent fc7f9cd commit ff51307
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 200 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public static long build(
if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) {
joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal();
} else {
joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal();
joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType(), false).ordinal();
}

return nativeBuild(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
Expand Down Expand Up @@ -382,4 +383,28 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def supportCartesianProductExec(): Boolean = true

/**
 * Predicate telling whether a join type may build its hash table on the left side.
 *
 * Accepts everything the parent implementation accepts, and additionally `LeftOuter`.
 */
override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
  joinType =>
    // Ask the parent first; only when it declines do we check the extra type allowed here.
    super.supportHashBuildJoinTypeOnLeft(joinType) || joinType == LeftOuter
}

/**
 * Predicate telling whether a join type may build its hash table on the right side.
 *
 * Accepts everything the parent implementation accepts, and additionally `RightOuter`.
 */
override def supportHashBuildJoinTypeOnRight: JoinType => Boolean = {
  joinType =>
    // Ask the parent first; only when it declines do we check the extra type allowed here.
    super.supportHashBuildJoinTypeOnRight(joinType) || joinType == RightOuter
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
Expand Down Expand Up @@ -605,7 +605,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark))
List(spark => RewriteSortMergeJoinToHashJoinRule(spark))

override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
List()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils._
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

import org.apache.spark.{broadcast, SparkContext}
Expand All @@ -40,30 +41,6 @@ object JoinTypeTransform {
joinType
}
}

/**
 * Maps a Spark join type, together with the chosen build side, to a Substrait join type.
 *
 * Only `RightOuter` depends on the build side: it becomes JOIN_TYPE_LEFT when building on the
 * left and JOIN_TYPE_RIGHT when building on the right. Any combination without a case below
 * yields UNRECOGNIZED.
 *
 * @param joinType  the Spark join type
 * @param buildSide the side used to build the hash table
 * @return the corresponding Substrait join type, or UNRECOGNIZED
 */
def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType =
  (joinType, buildSide) match {
    case (_: InnerLike, _) => JoinRel.JoinType.JOIN_TYPE_INNER
    case (FullOuter, _) => JoinRel.JoinType.JOIN_TYPE_OUTER
    case (LeftOuter, _) => JoinRel.JoinType.JOIN_TYPE_LEFT
    // The table order will be reversed in HashJoinLikeExecTransformer.
    case (RightOuter, BuildLeft) => JoinRel.JoinType.JOIN_TYPE_LEFT
    // This is the case rewritten in ReorderJoinLeftRightRule.
    case (RightOuter, BuildRight) => JoinRel.JoinType.JOIN_TYPE_RIGHT
    case (LeftSemi | ExistenceJoin(_), _) => JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
    case (LeftAnti, _) => JoinRel.JoinType.JOIN_TYPE_ANTI
    // TODO: Support cross join with Cross Rel
    case _ => JoinRel.JoinType.UNRECOGNIZED
  }
}

case class CHShuffledHashJoinExecTransformer(
Expand Down Expand Up @@ -102,8 +79,10 @@ case class CHShuffledHashJoinExecTransformer(
super.doValidateInternal()
}
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
// NOTE(review): this span defines `substraitJoinType` twice; it looks like both sides of a diff
// hunk (earlier form first, replacement second) - confirm before treating it as compilable code.
// First form: delegates to JoinTypeTransform.toSubstraitType with the join type and build side.
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)
// Second form: an ExistenceJoin maps to JOIN_TYPE_LEFT_SEMI; every other join type is
// converted by SubstraitUtil.toSubstrait, passing needSwitchChildren through.
override protected lazy val substraitJoinType: JoinRel.JoinType = joinType match {
case ExistenceJoin(_) => JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case _ => SubstraitUtil.toSubstrait(joinType, needSwitchChildren)
}
}

case class CHBroadcastBuildSideRDD(
Expand Down Expand Up @@ -210,6 +189,8 @@ case class CHBroadcastHashJoinExecTransformer(
// Substrait has no "left any" join, so a left semi join is used instead, and
// isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
// NOTE(review): this span defines `substraitJoinType` twice; it looks like both sides of a diff
// hunk (earlier form first, replacement second) - confirm before treating it as compilable code.
// First form: delegates to JoinTypeTransform.toSubstraitType with the join type and build side.
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)
// Second form: an ExistenceJoin maps to JOIN_TYPE_LEFT_SEMI; every other join type is
// converted by SubstraitUtil.toSubstrait with needSwitchChildren fixed to false.
override protected lazy val substraitJoinType: JoinRel.JoinType = joinType match {
case ExistenceJoin(_) => JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case _ => SubstraitUtil.toSubstrait(joinType, false)
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
}
}

// NOTE(review): this span defines `substraitJoinType` twice; it looks like both sides of a diff
// hunk (earlier form first, replacement second) - confirm before treating it as compilable code.
// First form: converts the join type with the one-argument SubstraitUtil.toSubstrait.
protected lazy val substraitJoinType: JoinRel.JoinType = SubstraitUtil.toSubstrait(joinType)
// Second form: same conversion, but also passes needSwitchChildren to SubstraitUtil.toSubstrait.
protected lazy val substraitJoinType: JoinRel.JoinType =
SubstraitUtil.toSubstrait(joinType, needSwitchChildren)
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genHashJoinTransformerMetricsUpdater(metrics)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ object OffloadJoin {
case Some(join: Join) =>
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
if (rightSize <= leftSize) BuildRight else BuildLeft
val leftRowCount = join.left.stats.rowCount
val rightRowCount = join.right.stats.rowCount
if (rightSize == leftSize && (rightRowCount.isDefined && leftRowCount.isDefined)) {
if (rightRowCount.get <= leftRowCount.get) BuildRight
else BuildLeft
} else if (rightSize <= leftSize) BuildRight
else BuildLeft
// Only a ShuffledHashJoinExec created directly in some Spark tests (e.g. OuterJoinSuite) is not
// linked to a logical plan.
case _ => shj.buildSide
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,24 +30,38 @@ import io.substrait.proto.{CrossRel, JoinRel}
import scala.collection.JavaConverters._

object SubstraitUtil {
/**
 * Converts a Spark join type to the Substrait `JoinRel` join type.
 *
 * `LeftOuter` and `RightOuter` both map to JOIN_TYPE_LEFT; join types without a case below
 * map to UNRECOGNIZED.
 *
 * @param sparkJoin the Spark join type to convert
 * @return the corresponding Substrait join type, or UNRECOGNIZED
 */
def toSubstrait(sparkJoin: JoinType): JoinRel.JoinType = {
  sparkJoin match {
    case _: InnerLike => JoinRel.JoinType.JOIN_TYPE_INNER
    case FullOuter => JoinRel.JoinType.JOIN_TYPE_OUTER
    case LeftOuter => JoinRel.JoinType.JOIN_TYPE_LEFT
    // The Substrait plan requires the right side to build the hash table, so for a RightOuter
    // join the left and right relations are exchanged and the join type is reverted.
    case RightOuter => JoinRel.JoinType.JOIN_TYPE_LEFT
    case LeftSemi => JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
    case LeftAnti => JoinRel.JoinType.JOIN_TYPE_ANTI
    // TODO: Support existence join
    case _ => JoinRel.JoinType.UNRECOGNIZED
  }
}
/**
 * Converts a Spark join type to the Substrait `JoinRel` join type.
 *
 * @param sparkJoin          the Spark join type to convert
 * @param needSwitchChildren when true, the outer-join direction is mirrored
 *                           (LeftOuter -> JOIN_TYPE_RIGHT, RightOuter -> JOIN_TYPE_LEFT);
 *                           presumably set when the caller swaps the join's children -
 *                           confirm at the call sites
 * @return the corresponding Substrait join type, or UNRECOGNIZED for join types without a
 *         case below
 * @throws IllegalArgumentException if `needSwitchChildren` is true for LeftSemi or LeftAnti
 */
def toSubstrait(sparkJoin: JoinType, needSwitchChildren: Boolean): JoinRel.JoinType =
  sparkJoin match {
    case _: InnerLike =>
      JoinRel.JoinType.JOIN_TYPE_INNER
    case FullOuter =>
      JoinRel.JoinType.JOIN_TYPE_OUTER
    case LeftOuter =>
      if (needSwitchChildren) {
        JoinRel.JoinType.JOIN_TYPE_RIGHT
      } else {
        JoinRel.JoinType.JOIN_TYPE_LEFT
      }
    case RightOuter =>
      if (needSwitchChildren) {
        JoinRel.JoinType.JOIN_TYPE_LEFT
      } else {
        JoinRel.JoinType.JOIN_TYPE_RIGHT
      }
    case LeftSemi =>
      if (needSwitchChildren) {
        throw new IllegalArgumentException("LeftSemi join should not switch children")
      }
      JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
    case LeftAnti =>
      if (needSwitchChildren) {
        // The message names the join type actually rejected; it previously said "LeftSemi".
        throw new IllegalArgumentException("LeftAnti join should not switch children")
      }
      JoinRel.JoinType.JOIN_TYPE_ANTI
    case _ =>
      // TODO: Support cross join with Cross Rel
      JoinRel.JoinType.UNRECOGNIZED
  }

def toCrossRelSubstrait(sparkJoin: JoinType): CrossRel.JoinType = sparkJoin match {
case _: InnerLike =>
Expand Down

0 comments on commit ff51307

Please sign in to comment.