From c5f61add2291ebf9a4c5cb954e1de697f431f880 Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 15 Aug 2024 11:36:39 +0800 Subject: [PATCH] refactor reordering join tables --- .../gluten/vectorized/StorageJoinBuilder.java | 4 +- .../backendsapi/clickhouse/CHBackend.scala | 25 +++ .../clickhouse/CHSparkPlanExecApi.scala | 4 +- .../execution/CHHashJoinExecTransformer.scala | 35 ++-- .../extension/ReorderJoinTablesRule.scala | 149 ------------------ .../execution/JoinExecTransformer.scala | 3 +- .../columnar/OffloadSingleNode.scala | 8 +- 7 files changed, 62 insertions(+), 166 deletions(-) delete mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java index ae7b89120cd4d..597857291ba54 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java @@ -17,6 +17,7 @@ package org.apache.gluten.vectorized; import org.apache.gluten.execution.BroadCastHashJoinContext; +import org.apache.gluten.execution.JoinTypeTransform; import org.apache.gluten.expression.ConverterUtils; import org.apache.gluten.expression.ConverterUtils$; import org.apache.gluten.substrait.type.TypeNode; @@ -80,7 +81,8 @@ public static long build( if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) { joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal(); } else { - joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(); + joinType = + JoinTypeTransform.toSubstraitJoinType(broadCastContext.joinType(), false).ordinal(); } return nativeBuild( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index 4677a28e61f37..9884a0c6ef39f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -382,4 +383,28 @@ object CHBackendSettings extends BackendSettingsApi with Logging { override def supportCartesianProductExec(): Boolean = true + override def supportHashBuildJoinTypeOnLeft: JoinType => Boolean = { + t => + if (super.supportHashBuildJoinTypeOnLeft(t)) { + true + } else { + t match { + case LeftOuter => true + case _ => false + } + } + } + + override def supportHashBuildJoinTypeOnRight: JoinType => Boolean = { + t => + if (super.supportHashBuildJoinTypeOnRight(t)) { + true + } else { + t match { + case RightOuter => true + case _ => false + } + } + } + } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 03e5aaa538a9f..2ba047ba3f012 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi} import org.apache.gluten.exception.GlutenNotSupportException import org.apache.gluten.execution._ import org.apache.gluten.expression._ -import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule} +import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule} import org.apache.gluten.extension.columnar.AddFallbackTagRule import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides import org.apache.gluten.extension.columnar.transition.Convention @@ -605,7 +605,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { * @return */ override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark)) + List(spark => RewriteSortMergeJoinToHashJoinRule(spark)) override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List() diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 7080e55dc1863..4ab03cc9bea0f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -17,6 +17,7 @@ package org.apache.gluten.execution import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.utils._ import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} import org.apache.spark.{broadcast, SparkContext} @@ -41,29 +42,39 @@ object JoinTypeTransform { } } - def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType = { - joinType match { + def toSubstraitJoinType(sparkJoin: JoinType, needSwitchChildren: Boolean): JoinRel.JoinType = + sparkJoin match { case _: InnerLike => JoinRel.JoinType.JOIN_TYPE_INNER case FullOuter => JoinRel.JoinType.JOIN_TYPE_OUTER case LeftOuter => - JoinRel.JoinType.JOIN_TYPE_LEFT - case RightOuter if (buildSide == BuildLeft) => - // The tables order will be reversed in HashJoinLikeExecTransformer - JoinRel.JoinType.JOIN_TYPE_LEFT - case RightOuter if (buildSide == BuildRight) => - // This the case rewritten in ReorderJoinLeftRightRule - JoinRel.JoinType.JOIN_TYPE_RIGHT + if (needSwitchChildren) { + JoinRel.JoinType.JOIN_TYPE_RIGHT + } else { + JoinRel.JoinType.JOIN_TYPE_LEFT + } + case RightOuter => + if (needSwitchChildren) { + JoinRel.JoinType.JOIN_TYPE_LEFT + } else { + JoinRel.JoinType.JOIN_TYPE_RIGHT + } case LeftSemi | ExistenceJoin(_) => + if (needSwitchChildren) { + throw new IllegalArgumentException("LeftSemi join should not switch children") + } JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI case LeftAnti => + if (needSwitchChildren) { + throw new IllegalArgumentException("LeftAnti join should not switch children") + } JoinRel.JoinType.JOIN_TYPE_ANTI case _ => // TODO: Support cross join with Cross Rel JoinRel.JoinType.UNRECOGNIZED } - } + } case class CHShuffledHashJoinExecTransformer( @@ -103,7 +114,7 @@ case class CHShuffledHashJoinExecTransformer( } private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = - JoinTypeTransform.toSubstraitType(joinType, buildSide) + JoinTypeTransform.toSubstraitJoinType(joinType, needSwitchChildren) } case class CHBroadcastBuildSideRDD( @@ -211,5 +222,5 @@ case class CHBroadcastHashJoinExecTransformer( // and isExistenceJoin is set to true to indicate that it is an existence join. private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = - JoinTypeTransform.toSubstraitType(joinType, buildSide) + JoinTypeTransform.toSubstraitJoinType(joinType, false) } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala deleted file mode 100644 index 4cedaae25684c..0000000000000 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.extension - -import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings -import org.apache.gluten.execution._ - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.optimizer._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive._ - -case class ReorderJoinTablesRule(session: SparkSession) extends Rule[SparkPlan] with Logging { - override def apply(plan: SparkPlan): SparkPlan = { - if (CHBackendSettings.enableReorderHashJoinTables) { - visitPlan(plan) - } else { - plan - } - } - - private def visitPlan(plan: SparkPlan): SparkPlan = { - plan match { - case hashShuffle: ColumnarShuffleExchangeExec => - hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan)) - case hashJoin: CHShuffledHashJoinExecTransformer => - val newHashJoin = reorderHashJoin(hashJoin) - newHashJoin.withNewChildren(newHashJoin.children.map(visitPlan)) - case _ => - plan.withNewChildren(plan.children.map(visitPlan)) - } - } - - private def reorderHashJoin(hashJoin: CHShuffledHashJoinExecTransformer): SparkPlan = { - val leftQueryStageRow = childShuffleQueryStageRows(hashJoin.left) - val rightQueryStageRow = childShuffleQueryStageRows(hashJoin.right) - if (leftQueryStageRow == None || rightQueryStageRow == None) { - logError(s"Cannot reorder this hash join. Its children is not ShuffleQueryStageExec") - hashJoin - } else { - val threshold = CHBackendSettings.reorderHashJoinTablesThreshold - val isLeftLarger = leftQueryStageRow.get > rightQueryStageRow.get * threshold - val isRightLarger = leftQueryStageRow.get * threshold < rightQueryStageRow.get - hashJoin.joinType match { - case Inner => - if (isRightLarger && hashJoin.buildSide == BuildRight) { - CHShuffledHashJoinExecTransformer( - hashJoin.rightKeys, - hashJoin.leftKeys, - hashJoin.joinType, - hashJoin.buildSide, - hashJoin.condition, - hashJoin.right, - hashJoin.left, - hashJoin.isSkewJoin) - } else if (isLeftLarger && hashJoin.buildSide == BuildLeft) { - CHShuffledHashJoinExecTransformer( - hashJoin.leftKeys, - hashJoin.rightKeys, - hashJoin.joinType, - BuildRight, - hashJoin.condition, - hashJoin.left, - hashJoin.right, - hashJoin.isSkewJoin) - } else { - hashJoin - } - case LeftOuter => - // left outer + build right is the common case,other cases have not been covered by tests - // and don't reroder them. - if (isRightLarger && hashJoin.buildSide == BuildRight) { - CHShuffledHashJoinExecTransformer( - hashJoin.rightKeys, - hashJoin.leftKeys, - RightOuter, - BuildRight, - hashJoin.condition, - hashJoin.right, - hashJoin.left, - hashJoin.isSkewJoin) - } else { - hashJoin - } - case RightOuter => - // right outer + build left is the common case,other cases have not been covered by tests - // and don't reroder them. - if (isLeftLarger && hashJoin.buildSide == BuildLeft) { - CHShuffledHashJoinExecTransformer( - hashJoin.leftKeys, - hashJoin.rightKeys, - RightOuter, - BuildRight, - hashJoin.condition, - hashJoin.left, - hashJoin.right, - hashJoin.isSkewJoin) - } else if (isRightLarger && hashJoin.buildSide == BuildLeft) { - CHShuffledHashJoinExecTransformer( - hashJoin.rightKeys, - hashJoin.leftKeys, - LeftOuter, - BuildRight, - hashJoin.condition, - hashJoin.right, - hashJoin.left, - hashJoin.isSkewJoin) - } else { - hashJoin - } - case _ => hashJoin - } - } - } - - private def childShuffleQueryStageRows(plan: SparkPlan): Option[BigInt] = { - plan match { - case queryStage: ShuffleQueryStageExec => - queryStage.getRuntimeStatistics.rowCount - case _: ColumnarBroadcastExchangeExec => - None - case _: ColumnarShuffleExchangeExec => - None - case _ => - if (plan.children.length == 1) { - childShuffleQueryStageRows(plan.children.head) - } else { - None - } - } - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala index 86e6c1f412656..c0c88fb893278 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinExecTransformer.scala @@ -167,7 +167,8 @@ trait HashJoinLikeExecTransformer extends BaseJoinExec with TransformSupport { } } - protected lazy val substraitJoinType: JoinRel.JoinType = SubstraitUtil.toSubstrait(joinType) + protected lazy val substraitJoinType: JoinRel.JoinType = + SubstraitUtil.toSubstrait(joinType) override def metricsUpdater(): MetricsUpdater = BackendsApiManager.getMetricsApiInstance.genHashJoinTransformerMetricsUpdater(metrics) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala index a8cc791286b2a..2b61c6935b3e4 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/OffloadSingleNode.scala @@ -205,7 +205,13 @@ object OffloadJoin { case Some(join: Join) => val leftSize = join.left.stats.sizeInBytes val rightSize = join.right.stats.sizeInBytes - if (rightSize <= leftSize) BuildRight else BuildLeft + val leftRowCount = join.left.stats.rowCount + val rightRowCount = join.right.stats.rowCount + if (rightSize == leftSize && (rightRowCount.isDefined && leftRowCount.isDefined)) { + if (rightRowCount.get <= leftRowCount.get) BuildRight + else BuildLeft + } else if (rightSize <= leftSize) BuildRight + else BuildLeft // Only the ShuffledHashJoinExec generated directly in some spark tests is not link // logical plan, such as OuterJoinSuite. case _ => shj.buildSide