[GLUTEN-6768][CH] Try to reorder hash join tables based on AQE statistics (#6770)

[CH] Try to reorder hash join tables based on AQE statistics
lgbo-ustc authored Aug 14, 2024
1 parent db799a4 commit fc7f9cd
Showing 5 changed files with 270 additions and 8 deletions.
@@ -357,6 +357,23 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
.getLong(GLUTEN_MAX_SHUFFLE_READ_BYTES, GLUTEN_MAX_SHUFFLE_READ_BYTES_DEFAULT)
}

// Reorder hash join tables so that the smaller table is used to build the hash table.
// Requires AQE to be enabled.
def enableReorderHashJoinTables(): Boolean = {
SparkEnv.get.conf.getBoolean(
"spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables",
true
)
}
// The threshold for reordering hash join tables: if the ratio of the two tables'
// row counts is larger than this threshold, reorder the tables,
// e.g. a/b > threshold or b/a > threshold.
def reorderHashJoinTablesThreshold(): Int = {
SparkEnv.get.conf.getInt(
"spark.gluten.sql.columnar.backend.ch.reorder_hash_join_tables_thresdhold",
10
)
}

override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}
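
For illustration, a minimal sketch of enabling the two new settings above when building a session (key names copied from this diff, including the committed "thresdhold" spelling; the other Gluten/ClickHouse configs a real session needs are omitted):

import org.apache.spark.sql.SparkSession

val spark = SparkSession
  .builder()
  // AQE must be on: the rule reads row counts from ShuffleQueryStageExec statistics.
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables", "true")
  // Reorder only when one side is more than 10x the other (the default).
  .config("spark.gluten.sql.columnar.backend.ch.reorder_hash_join_tables_thresdhold", "10")
  .getOrCreate()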
@@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
@@ -605,7 +605,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark))

override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
List()
@@ -23,7 +23,7 @@ import org.apache.spark.{broadcast, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
@@ -41,14 +41,20 @@ object JoinTypeTransform {
}
}

def toSubstraitType(joinType: JoinType): JoinRel.JoinType = {
def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType = {
joinType match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
case FullOuter =>
JoinRel.JoinType.JOIN_TYPE_OUTER
case LeftOuter | RightOuter =>
case LeftOuter =>
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildLeft) =>
// The table order will be reversed in HashJoinLikeExecTransformer
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildRight) =>
// This is the case rewritten by ReorderJoinTablesRule
JoinRel.JoinType.JOIN_TYPE_RIGHT
case LeftSemi | ExistenceJoin(_) =>
JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case LeftAnti =>
@@ -97,7 +103,7 @@ case class CHShuffledHashJoinExecTransformer(
}
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType)
JoinTypeTransform.toSubstraitType(joinType, buildSide)
}

case class CHBroadcastBuildSideRDD(
@@ -205,5 +211,5 @@ case class CHBroadcastHashJoinExecTransformer(
// and isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType)
JoinTypeTransform.toSubstraitType(joinType, buildSide)
}
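
To make the new buildSide parameter concrete, a hedged sketch of the RightOuter mapping (illustrative only, not part of the diff; assumes this file's imports are in scope):

// RightOuter + BuildLeft: HashJoinLikeExecTransformer reverses the table order,
// so the native plan receives a LEFT join.
assert(JoinTypeTransform.toSubstraitType(RightOuter, BuildLeft) == JoinRel.JoinType.JOIN_TYPE_LEFT)
// RightOuter + BuildRight: the order produced by ReorderJoinTablesRule is kept.
assert(JoinTypeTransform.toSubstraitType(RightOuter, BuildRight) == JoinRel.JoinType.JOIN_TYPE_RIGHT)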
@@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension

import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
import org.apache.gluten.execution._

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._

case class ReorderJoinTablesRule(session: SparkSession) extends Rule[SparkPlan] with Logging {
override def apply(plan: SparkPlan): SparkPlan = {
if (CHBackendSettings.enableReorderHashJoinTables) {
visitPlan(plan)
} else {
plan
}
}

private def visitPlan(plan: SparkPlan): SparkPlan = {
plan match {
case hashShuffle: ColumnarShuffleExchangeExec =>
hashShuffle.withNewChildren(hashShuffle.children.map(visitPlan))
case hashJoin: CHShuffledHashJoinExecTransformer =>
val newHashJoin = reorderHashJoin(hashJoin)
newHashJoin.withNewChildren(newHashJoin.children.map(visitPlan))
case _ =>
plan.withNewChildren(plan.children.map(visitPlan))
}
}

private def reorderHashJoin(hashJoin: CHShuffledHashJoinExecTransformer): SparkPlan = {
val leftQueryStageRow = childShuffleQueryStageRows(hashJoin.left)
val rightQueryStageRow = childShuffleQueryStageRows(hashJoin.right)
if (leftQueryStageRow.isEmpty || rightQueryStageRow.isEmpty) {
logError("Cannot reorder this hash join: its children are not ShuffleQueryStageExec")
hashJoin
} else {
val threshold = CHBackendSettings.reorderHashJoinTablesThreshold
val isLeftLarger = leftQueryStageRow.get > rightQueryStageRow.get * threshold
val isRightLarger = leftQueryStageRow.get * threshold < rightQueryStageRow.get
hashJoin.joinType match {
case Inner =>
if (isRightLarger && hashJoin.buildSide == BuildRight) {
CHShuffledHashJoinExecTransformer(
hashJoin.rightKeys,
hashJoin.leftKeys,
hashJoin.joinType,
hashJoin.buildSide,
hashJoin.condition,
hashJoin.right,
hashJoin.left,
hashJoin.isSkewJoin)
} else if (isLeftLarger && hashJoin.buildSide == BuildLeft) {
CHShuffledHashJoinExecTransformer(
hashJoin.leftKeys,
hashJoin.rightKeys,
hashJoin.joinType,
BuildRight,
hashJoin.condition,
hashJoin.left,
hashJoin.right,
hashJoin.isSkewJoin)
} else {
hashJoin
}
case LeftOuter =>
// Left outer + build right is the common case; other cases are not covered by
// tests, so don't reorder them.
if (isRightLarger && hashJoin.buildSide == BuildRight) {
CHShuffledHashJoinExecTransformer(
hashJoin.rightKeys,
hashJoin.leftKeys,
RightOuter,
BuildRight,
hashJoin.condition,
hashJoin.right,
hashJoin.left,
hashJoin.isSkewJoin)
} else {
hashJoin
}
case RightOuter =>
// Right outer + build left is the common case; other cases are not covered by
// tests, so don't reorder them.
if (isLeftLarger && hashJoin.buildSide == BuildLeft) {
CHShuffledHashJoinExecTransformer(
hashJoin.leftKeys,
hashJoin.rightKeys,
RightOuter,
BuildRight,
hashJoin.condition,
hashJoin.left,
hashJoin.right,
hashJoin.isSkewJoin)
} else if (isRightLarger && hashJoin.buildSide == BuildLeft) {
CHShuffledHashJoinExecTransformer(
hashJoin.rightKeys,
hashJoin.leftKeys,
LeftOuter,
BuildRight,
hashJoin.condition,
hashJoin.right,
hashJoin.left,
hashJoin.isSkewJoin)
} else {
hashJoin
}
case _ => hashJoin
}
}
}

private def childShuffleQueryStageRows(plan: SparkPlan): Option[BigInt] = {
plan match {
case queryStage: ShuffleQueryStageExec =>
queryStage.getRuntimeStatistics.rowCount
case _: ColumnarBroadcastExchangeExec =>
None
case _: ColumnarShuffleExchangeExec =>
None
case _ =>
if (plan.children.length == 1) {
childShuffleQueryStageRows(plan.children.head)
} else {
None
}
}
}
}
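
A worked example of the ratio test in reorderHashJoin (illustrative row counts; threshold is the default 10):

val threshold = 10
val leftRows = BigInt(100000) // row count of the left child's ShuffleQueryStageExec
val rightRows = BigInt(100)   // row count of the right child's ShuffleQueryStageExec
val isLeftLarger = leftRows > rightRows * threshold  // 100000 > 1000 => true
val isRightLarger = leftRows * threshold < rightRows // 1000000 < 100 => false
// For an Inner join built as BuildLeft, the rule keeps the child order and flips
// the build side to BuildRight, so the smaller right table builds the hash table.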
@@ -17,12 +17,17 @@
package org.apache.gluten.execution

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.CoalescedPartitionSpec
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec}

class GlutenClickHouseColumnarShuffleAQESuite
extends GlutenClickHouseTPCHAbstractSuite
with AdaptiveSparkPlanHelper {
with AdaptiveSparkPlanHelper
with Logging {

override protected val tablesPath: String = basePath + "/tpch-data-ch"
override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
@@ -171,4 +176,89 @@ class GlutenClickHouseColumnarShuffleAQESuite
assert(adaptiveSparkPlanExec(1) == adaptiveSparkPlanExec(2))
}
}

test("GLUTEN-6768 rerorder hash join") {
withSQLConf(
("spark.gluten.sql.columnar.backend.ch.enable_reorder_hash_join_tables", "true"),
("spark.sql.adaptive.enabled", "true")) {
spark.sql("create table t1(a int, b int) using parquet")
spark.sql("create table t2(a int, b int) using parquet")

spark.sql("insert into t1 select id as a, id as b from range(100000)")
spark.sql("insert into t1 select id as a, id as b from range(100)")

def isExpectedJoinNode(plan: SparkPlan, joinType: JoinType, buildSide: BuildSide): Boolean = {
plan match {
case join: CHShuffledHashJoinExecTransformer =>
join.joinType == joinType && join.buildSide == buildSide
case _ => false
}
}

def collectExpectedJoinNode(
plan: SparkPlan,
joinType: JoinType,
buildSide: BuildSide): Seq[SparkPlan] = {
if (isExpectedJoinNode(plan, joinType, buildSide)) {
Seq(plan) ++ plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
} else {
plan.children.flatMap(collectExpectedJoinNode(_, joinType, buildSide))
}
}

var sql = """
|select * from t2 left join t1 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql,
true,
{
df =>
val joins = df.queryExecution.executedPlan.collect {
case adaptiveNode: AdaptiveSparkPlanExec =>
collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
}.flatten
assert(joins.size == 1)
}
)

sql = """
|select * from t2 right join t1 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql,
true,
{
df =>
val joins = df.queryExecution.executedPlan.collect {
case adaptiveNode: AdaptiveSparkPlanExec =>
collectExpectedJoinNode(adaptiveNode.executedPlan, LeftOuter, BuildRight)
}.flatten
assert(joins.size == 1)
}
)

sql = """
|select * from t1 right join t2 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(
sql,
true,
{
df =>
val joins = df.queryExecution.executedPlan.collect {
case adaptiveNode: AdaptiveSparkPlanExec =>
collectExpectedJoinNode(adaptiveNode.executedPlan, RightOuter, BuildRight)
}.flatten
assert(joins.size == 1)
}
)

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
}
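
A brief walk-through of why the first assertion holds (assuming the original LeftOuter join builds right, which is the case this rule targets): after the inserts, t1 holds 100,000 rows and t2 holds 100, so for "select * from t2 left join t1" the right child is roughly 1,000x larger than the left, far beyond the default threshold of 10. ReorderJoinTablesRule therefore swaps the children and rewrites LeftOuter to RightOuter with BuildRight, which is exactly the node the test looks for; the two right-join queries exercise the analogous RightOuter branches.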
