From fc7f9cd0c46ab46ac69a5e208a4ae2763affec7f Mon Sep 17 00:00:00 2001
From: lgbo
Date: Wed, 14 Aug 2024 10:26:43 +0800
Subject: [PATCH] [GLUTEN-6768][CH] Try to reorder hash join tables based on AQE statistics (#6770)

[CH] Try to reorder hash join tables based on AQE statistics
---
 .../backendsapi/clickhouse/CHBackend.scala    |  17 ++
 .../clickhouse/CHSparkPlanExecApi.scala       |   4 +-
 .../execution/CHHashJoinExecTransformer.scala |  16 +-
 .../extension/ReorderJoinTablesRule.scala     | 149 ++++++++++++++++++
 ...tenClickHouseColumnarShuffleAQESuite.scala |  92 ++++++++++-
 5 files changed, 270 insertions(+), 8 deletions(-)
 create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index e53f03300f32..4677a28e61f3 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -357,6 +357,23 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
       .getLong(GLUTEN_MAX_SHUFFLE_READ_BYTES, GLUTEN_MAX_SHUFFLE_READ_BYTES_DEFAULT)
   }
 
+  // Reorder hash join tables so that the smaller table is used to build the hash table.
+  // Requires AQE to be enabled.
+  def enableReorderHashJoinTables(): Boolean = {
+    SparkEnv.get.conf.getBoolean(
+      "spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables",
+      true
+    )
+  }
+  // The threshold for reordering hash join tables: if the ratio of the two tables' sizes is
+  // larger than this threshold, reorder the tables, e.g. a/b > threshold or b/a > threshold.
+  def reorderHashJoinTablesThreshold(): Int = {
+    SparkEnv.get.conf.getInt(
+      "spark.gluten.sql.columnar.backend.ch.reorder_hash_join_tables_thresdhold",
+      10
+    )
+  }
+
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
   }
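For intuition, a minimal standalone sketch of the decision these two settings drive; the row counts below are hypothetical stand-ins for AQE shuffle statistics, not values from this patch:

    object ReorderThresholdSketch extends App {
      // Hypothetical AQE row counts for the two join children.
      val leftRows: BigInt = BigInt(100000)
      val rightRows: BigInt = BigInt(100)
      val threshold = 10 // default of reorderHashJoinTablesThreshold

      // A side counts as "larger" only when it exceeds the other side by more
      // than the threshold ratio, mirroring a/b > threshold or b/a > threshold.
      val isLeftLarger = leftRows > rightRows * threshold // true: 100000 > 1000
      val isRightLarger = leftRows * threshold < rightRows // false
      println(s"swap so the smaller (right) side builds: $isLeftLarger")
    }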
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
index 2ba047ba3f01..03e5aaa538a9 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.execution._
 import org.apache.gluten.expression._
-import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
+import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
 import org.apache.gluten.extension.columnar.AddFallbackTagRule
 import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
 import org.apache.gluten.extension.columnar.transition.Convention
@@ -605,7 +605,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
    * @return
    */
   override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
-    List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
+    List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark))
 
   override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
     List()
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
index ed946e1d263d..7080e55dc186 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.{broadcast, SparkContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rpc.GlutenDriverEndpoint
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.BuildSide
+import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -41,14 +41,20 @@ object JoinTypeTransform {
     }
   }
 
-  def toSubstraitType(joinType: JoinType): JoinRel.JoinType = {
+  def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType = {
     joinType match {
       case _: InnerLike =>
         JoinRel.JoinType.JOIN_TYPE_INNER
       case FullOuter =>
         JoinRel.JoinType.JOIN_TYPE_OUTER
-      case LeftOuter | RightOuter =>
+      case LeftOuter =>
         JoinRel.JoinType.JOIN_TYPE_LEFT
+      case RightOuter if (buildSide == BuildLeft) =>
+        // The table order will be reversed in HashJoinLikeExecTransformer
+        JoinRel.JoinType.JOIN_TYPE_LEFT
+      case RightOuter if (buildSide == BuildRight) =>
+        // This is the case rewritten by ReorderJoinTablesRule
+        JoinRel.JoinType.JOIN_TYPE_RIGHT
       case LeftSemi | ExistenceJoin(_) =>
         JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
       case LeftAnti =>
@@ -97,7 +103,7 @@ case class CHShuffledHashJoinExecTransformer(
   }
   private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
   override protected lazy val substraitJoinType: JoinRel.JoinType =
-    JoinTypeTransform.toSubstraitType(joinType)
+    JoinTypeTransform.toSubstraitType(joinType, buildSide)
 }
 
 case class CHBroadcastBuildSideRDD(
@@ -205,5 +211,5 @@ case class CHBroadcastHashJoinExecTransformer(
   // and isExistenceJoin is set to true to indicate that it is an existence join.
   private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
   override protected lazy val substraitJoinType: JoinRel.JoinType =
-    JoinTypeTransform.toSubstraitType(joinType)
+    JoinTypeTransform.toSubstraitType(joinType, buildSide)
 }
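The two new RightOuter branches encode a table-order subtlety: when the build side is the left table, HashJoinLikeExecTransformer reverses the table order, so the Substrait join type has to flip with it. A standalone sketch of that mapping with simplified stand-in types; none of these names come from the patch:

    object JoinTypeFlipSketch {
      sealed trait Side
      case object BuildL extends Side
      case object BuildR extends Side

      // For a RIGHT OUTER join, emitting (right, left) instead of (left, right)
      // turns it into a LEFT OUTER join over the reversed operands.
      def substraitTypeForRightOuter(buildSide: Side): String = buildSide match {
        case BuildL => "JOIN_TYPE_LEFT" // table order reversed downstream
        case BuildR => "JOIN_TYPE_RIGHT" // order preserved; set by ReorderJoinTablesRule
      }
    }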
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala
new file mode 100644
index 000000000000..4cedaae25684
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/ReorderJoinTablesRule.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+import org.apache.gluten.execution._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.adaptive._
+
+case class ReorderJoinTablesRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
+  override def apply(plan: SparkPlan): SparkPlan = {
+    if (CHBackendSettings.enableReorderHashJoinTables) {
+      visitPlan(plan)
+    } else {
+      plan
+    }
+  }
+
+  private def visitPlan(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case hashShuffle: ColumnarShuffleExchangeExec =>
+        hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan))
+      case hashJoin: CHShuffledHashJoinExecTransformer =>
+        val newHashJoin = reorderHashJoin(hashJoin)
+        newHashJoin.withNewChildren(newHashJoin.children.map(visitPlan))
+      case _ =>
+        plan.withNewChildren(plan.children.map(visitPlan))
+    }
+  }
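visitPlan above is a plain top-down rewrite: rewrite the node of interest first, then rebuild every node from its visited children. For readers unfamiliar with this style of SparkPlan rule, a toy version of the same pattern over a made-up tree type; all names here are illustrative:

    object VisitSketch {
      sealed trait Node {
        def children: Seq[Node]
        def withChildren(c: Seq[Node]): Node
      }
      case class Leaf(rows: Long) extends Node {
        def children: Seq[Node] = Nil
        def withChildren(c: Seq[Node]): Node = this
      }
      case class Join(children: Seq[Node]) extends Node {
        def withChildren(c: Seq[Node]): Node = copy(children = c)
      }

      // Swap a join's children when the left side is more than 10x larger,
      // then keep rewriting the (possibly new) children recursively.
      def visit(n: Node): Node = n match {
        case Join(Seq(Leaf(l), Leaf(r))) if l > r * 10 =>
          visitChildren(Join(Seq(Leaf(r), Leaf(l))))
        case other => visitChildren(other)
      }
      private def visitChildren(n: Node): Node = n.withChildren(n.children.map(visit))
    }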
+  private def reorderHashJoin(hashJoin: CHShuffledHashJoinExecTransformer): SparkPlan = {
+    val leftQueryStageRow = childShuffleQueryStageRows(hashJoin.left)
+    val rightQueryStageRow = childShuffleQueryStageRows(hashJoin.right)
+    if (leftQueryStageRow.isEmpty || rightQueryStageRow.isEmpty) {
+      logError("Cannot reorder this hash join: its children are not ShuffleQueryStageExec")
+      hashJoin
+    } else {
+      val threshold = CHBackendSettings.reorderHashJoinTablesThreshold
+      val isLeftLarger = leftQueryStageRow.get > rightQueryStageRow.get * threshold
+      val isRightLarger = leftQueryStageRow.get * threshold < rightQueryStageRow.get
+      hashJoin.joinType match {
+        case Inner =>
+          if (isRightLarger && hashJoin.buildSide == BuildRight) {
+            CHShuffledHashJoinExecTransformer(
+              hashJoin.rightKeys,
+              hashJoin.leftKeys,
+              hashJoin.joinType,
+              hashJoin.buildSide,
+              hashJoin.condition,
+              hashJoin.right,
+              hashJoin.left,
+              hashJoin.isSkewJoin)
+          } else if (isLeftLarger && hashJoin.buildSide == BuildLeft) {
+            CHShuffledHashJoinExecTransformer(
+              hashJoin.leftKeys,
+              hashJoin.rightKeys,
+              hashJoin.joinType,
+              BuildRight,
+              hashJoin.condition,
+              hashJoin.left,
+              hashJoin.right,
+              hashJoin.isSkewJoin)
+          } else {
+            hashJoin
+          }
+        case LeftOuter =>
+          // Left outer + build right is the common case; other cases are not covered by tests,
+          // so don't reorder them.
+          if (isRightLarger && hashJoin.buildSide == BuildRight) {
+            CHShuffledHashJoinExecTransformer(
+              hashJoin.rightKeys,
+              hashJoin.leftKeys,
+              RightOuter,
+              BuildRight,
+              hashJoin.condition,
+              hashJoin.right,
+              hashJoin.left,
+              hashJoin.isSkewJoin)
+          } else {
+            hashJoin
+          }
+        case RightOuter =>
+          // Right outer + build left is the common case; other cases are not covered by tests,
+          // so don't reorder them.
+          if (isLeftLarger && hashJoin.buildSide == BuildLeft) {
+            CHShuffledHashJoinExecTransformer(
+              hashJoin.leftKeys,
+              hashJoin.rightKeys,
+              RightOuter,
+              BuildRight,
+              hashJoin.condition,
+              hashJoin.left,
+              hashJoin.right,
+              hashJoin.isSkewJoin)
+          } else if (isRightLarger && hashJoin.buildSide == BuildLeft) {
+            CHShuffledHashJoinExecTransformer(
+              hashJoin.rightKeys,
+              hashJoin.leftKeys,
+              LeftOuter,
+              BuildRight,
+              hashJoin.condition,
+              hashJoin.right,
+              hashJoin.left,
+              hashJoin.isSkewJoin)
+          } else {
+            hashJoin
+          }
+        case _ => hashJoin
+      }
+    }
+  }
+
+  private def childShuffleQueryStageRows(plan: SparkPlan): Option[BigInt] = {
+    plan match {
+      case queryStage: ShuffleQueryStageExec =>
+        queryStage.getRuntimeStatistics.rowCount
+      case _: ColumnarBroadcastExchangeExec =>
+        None
+      case _: ColumnarShuffleExchangeExec =>
+        None
+      case _ =>
+        if (plan.children.length == 1) {
+          childShuffleQueryStageRows(plan.children.head)
+        } else {
+          None
+        }
+    }
+  }
+}
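To exercise the rule end to end, a spark-shell session against a Gluten ClickHouse build could be wired up roughly as follows. This is a sketch: the app and view names and sizes are made up, and AQE must be on because the rule only sees statistics once shuffle query stages have materialized:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("reorder-hash-join-sketch") // hypothetical app name
      .config("spark.sql.adaptive.enabled", "true")
      .config("spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables", "true")
      // threshold key exactly as defined in CHBackend.scala above
      .config("spark.gluten.sql.columnar.backend.ch.reorder_hash_join_tables_thresdhold", "10")
      .getOrCreate()

    spark.range(100000).createOrReplaceTempView("big") // hypothetical large side
    spark.range(100).createOrReplaceTempView("small") // hypothetical small side
    spark.sql("select * from small left join big on small.id = big.id").collect()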
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
index f25b8643b707..fc22add2d880 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala
@@ -17,12 +17,17 @@
 package org.apache.gluten.execution
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.optimizer._
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.CoalescedPartitionSpec
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec}
 
 class GlutenClickHouseColumnarShuffleAQESuite
   extends GlutenClickHouseTPCHAbstractSuite
-  with AdaptiveSparkPlanHelper {
+  with AdaptiveSparkPlanHelper
+  with Logging {
 
   override protected val tablesPath: String = basePath + "/tpch-data-ch"
   override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
@@ -171,4 +176,89 @@ class GlutenClickHouseColumnarShuffleAQESuite
       assert(adaptiveSparkPlanExec(1) == adaptiveSparkPlanExec(2))
     }
   }
+
+  test("GLUTEN-6768 reorder hash join") {
+    withSQLConf(
+      ("spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables", "true"),
+      ("spark.sql.adaptive.enabled", "true")) {
+      spark.sql("create table t1(a int, b int) using parquet")
+      spark.sql("create table t2(a int, b int) using parquet")
+
+      spark.sql("insert into t1 select id as a, id as b from range(100000)")
+      spark.sql("insert into t2 select id as a, id as b from range(100)")
+
+      def isExpectedJoinNode(plan: SparkPlan, joinType: JoinType, buildSide: BuildSide): Boolean = {
+        plan match {
+          case join: CHShuffledHashJoinExecTransformer =>
+            join.joinType == joinType && join.buildSide == buildSide
+          case _ => false
+        }
+      }
+
+      def collectExpectedJoinNode(
+          plan: SparkPlan,
+          joinType: JoinType,
+          buildSide: BuildSide): Seq[SparkPlan] = {
+        if (isExpectedJoinNode(plan, joinType, buildSide)) {
+          Seq(plan) ++ plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
+        } else {
+          plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
+        }
+      }
+
+      var sql = """
+                  |select * from t2 left join t1 on t1.a = t2.a
+                  |""".stripMargin
+      compareResultsAgainstVanillaSpark(
+        sql,
+        true,
+        {
+          df =>
+            val joins = df.queryExecution.executedPlan.collect {
+              case adaptiveNode: AdaptiveSparkPlanExec =>
+                collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
+              case _ => Seq()
+            }.flatten
+            assert(joins.size == 1)
+        }
+      )
+
+      sql = """
+              |select * from t2 right join t1 on t1.a = t2.a
+              |""".stripMargin
+      compareResultsAgainstVanillaSpark(
+        sql,
+        true,
+        {
+          df =>
+            val joins = df.queryExecution.executedPlan.collect {
+              case adaptiveNode: AdaptiveSparkPlanExec =>
+                collectExpectedJoinNode(adaptiveNode.executedPlan, LeftOuter, BuildRight)
+              case _ => Seq()
+            }.flatten
+            assert(joins.size == 1)
+        }
+      )
+
+      sql = """
+              |select * from t1 right join t2 on t1.a = t2.a
+              |""".stripMargin
+      compareResultsAgainstVanillaSpark(
+        sql,
+        true,
+        {
+          df =>
+            val joins = df.queryExecution.executedPlan.collect {
+              case adaptiveNode: AdaptiveSparkPlanExec =>
+                collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
+              case _ => Seq()
+            }.flatten
+            assert(joins.size == 1)
+        }
+      )
+
+      spark.sql("drop table t1")
+      spark.sql("drop table t2")
+    }
+  }
 }
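As a usage-level follow-up to the suite above, the reordered join can also be inspected directly in the adaptive plan. A sketch assuming the spark-shell session and the hypothetical big/small views from the earlier example:

    import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec

    val df = spark.sql("select * from small left join big on small.id = big.id")
    df.collect() // force AQE to finalize the plan
    df.queryExecution.executedPlan match {
      // after reordering, the join is expected to appear as RightOuter with
      // BuildRight and swapped children, matching what the tests assert
      case a: AdaptiveSparkPlanExec => println(a.executedPlan.treeString)
      case other => println(other.treeString)
    }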