Skip to content

Commit

Permalink
refactor reordering join tables
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 15, 2024
1 parent fc7f9cd commit c5f61ad
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 166 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.vectorized;

import org.apache.gluten.execution.BroadCastHashJoinContext;
import org.apache.gluten.execution.JoinTypeTransform;
import org.apache.gluten.expression.ConverterUtils;
import org.apache.gluten.expression.ConverterUtils$;
import org.apache.gluten.substrait.type.TypeNode;
Expand Down Expand Up @@ -80,7 +81,8 @@ public static long build(
if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) {
joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal();
} else {
joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal();
joinType =
JoinTypeTransform.toSubstraitJoinType(broadCastContext.joinType(), false).ordinal();
}

return nativeBuild(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
Expand Down Expand Up @@ -382,4 +383,28 @@ object CHBackendSettings extends BackendSettingsApi with Logging {

override def supportCartesianProductExec(): Boolean = true

override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = {
t =>
if (super.supportHashBuildJoinTypeOnLeft(t)) {
true
} else {
t match {
case LeftOuter => true
case _ => false
}
}
}

override def supportHashBuildJoinTypeOnRight: JoinType => Boolean = {
t =>
if (super.supportHashBuildJoinTypeOnRight(t)) {
true
} else {
t match {
case RightOuter => true
case _ => false
}
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
Expand Down Expand Up @@ -605,7 +605,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark))
List(spark => RewriteSortMergeJoinToHashJoinRule(spark))

override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
List()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.gluten.execution

import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils._
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

import org.apache.spark.{broadcast, SparkContext}
Expand All @@ -41,29 +42,39 @@ object JoinTypeTransform {
}
}

def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType = {
joinType match {
def toSubstraitJoinType(sparkJoin: JoinType, needSwitchChildren: Boolean): JoinRel.JoinType =
sparkJoin match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
case FullOuter =>
JoinRel.JoinType.JOIN_TYPE_OUTER
case LeftOuter =>
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildLeft) =>
// The tables order will be reversed in HashJoinLikeExecTransformer
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildRight) =>
// This the case rewritten in ReorderJoinLeftRightRule
JoinRel.JoinType.JOIN_TYPE_RIGHT
if (needSwitchChildren) {
JoinRel.JoinType.JOIN_TYPE_RIGHT
} else {
JoinRel.JoinType.JOIN_TYPE_LEFT
}
case RightOuter =>
if (needSwitchChildren) {
JoinRel.JoinType.JOIN_TYPE_LEFT
} else {
JoinRel.JoinType.JOIN_TYPE_RIGHT
}
case LeftSemi | ExistenceJoin(_) =>
if (needSwitchChildren) {
throw new IllegalArgumentException("LeftSemi join should not switch children")
}
JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case LeftAnti =>
if (needSwitchChildren) {
throw new IllegalArgumentException("LeftAnti join should not switch children")
}
JoinRel.JoinType.JOIN_TYPE_ANTI
case _ =>
// TODO: Support cross join with Cross Rel
JoinRel.JoinType.UNRECOGNIZED
}
}

}

case class CHShuffledHashJoinExecTransformer(
Expand Down Expand Up @@ -103,7 +114,7 @@ case class CHShuffledHashJoinExecTransformer(
}
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)
JoinTypeTransform.toSubstraitJoinType(joinType, needSwitchChildren)
}

case class CHBroadcastBuildSideRDD(
Expand Down Expand Up @@ -211,5 +222,5 @@ case class CHBroadcastHashJoinExecTransformer(
// and isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)
JoinTypeTransform.toSubstraitJoinType(joinType, false)
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport {
}
}

protected lazy val substraitJoinType: JoinRel.JoinType = SubstraitUtil.toSubstrait(joinType)
protected lazy val substraitJoinType: JoinRel.JoinType =
SubstraitUtil.toSubstrait(joinType)
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genHashJoinTransformerMetricsUpdater(metrics)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,13 @@ object OffloadJoin {
case Some(join: Join) =>
val leftSize = join.left.stats.sizeInBytes
val rightSize = join.right.stats.sizeInBytes
if (rightSize <= leftSize) BuildRight else BuildLeft
val leftRowCount = join.left.stats.rowCount
val rightRowCount = join.right.stats.rowCount
if (rightSize == leftSize && (rightRowCount.isDefined && leftRowCount.isDefined)) {
if (rightRowCount.get <= leftRowCount.get) BuildRight
else BuildLeft
} else if (rightSize <= leftSize) BuildRight
else BuildLeft
// Only the ShuffledHashJoinExec generated directly in some spark tests is not link
// logical plan, such as OuterJoinSuite.
case _ => shj.buildSide
Expand Down

0 comments on commit c5f61ad

Please sign in to comment.