Skip to content

Commit

Permalink
reorder hash join tables based on aqe
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 9, 2024
1 parent f7e59be commit b4e5420
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,24 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
.getLong(GLUTEN_MAX_SHUFFLE_READ_BYTES, GLUTEN_MAX_SHUFFLE_READ_BYTES_DEFAULT)
}

// Reorder hash join tables, make sure to use the smaller table to build the hash table.
// Need to enable AQE
def enableReorderHashJoinTables(): Boolean = {
SparkEnv.get.conf.getBoolean(
"spark.gluten.sql.columnar.backend.ch.enable.reorder_hash_join_tables",
true
)
}
// The threshold to reorder hash join tables, if The result of dividing two tables' size is
// large then this threshold, reorder the tables. e.g. a/b > threshold or b/a > threshold
def reorderHashJoinTablesThreadhold(): Double = {
SparkEnv.get.conf.getInt(
"spark.gluten.sql.columnar.backend.ch.reorder_hash_join_tables_threadhold",
10
)
}


override def enableNativeWriteFiles(): Boolean = {
GlutenConfig.getConf.enableNativeWriter.getOrElse(false)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.gluten.backendsapi.{BackendsApiManager, SparkPlanExecApi}
import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.execution._
import org.apache.gluten.expression._
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.{CommonSubexpressionEliminateRule, CountDistinctWithoutExpand, FallbackBroadcastHashJoin, FallbackBroadcastHashJoinPrepQueryStage, ReorderJoinTablesRule, RewriteDateTimestampComparisonRule, RewriteSortMergeJoinToHashJoinRule, RewriteToDateExpresstionRule}
import org.apache.gluten.extension.columnar.AddFallbackTagRule
import org.apache.gluten.extension.columnar.MiscColumnarRules.TransformPreOverrides
import org.apache.gluten.extension.columnar.transition.Convention
Expand Down Expand Up @@ -601,7 +601,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
* @return
*/
override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] =
List(spark => RewriteSortMergeJoinToHashJoinRule(spark))
List(spark => RewriteSortMergeJoinToHashJoinRule(spark), spark => ReorderJoinTablesRule(spark))

override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = {
List()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.{broadcast, SparkContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
Expand All @@ -41,14 +41,20 @@ object JoinTypeTransform {
}
}

def toSubstraitType(joinType: JoinType): JoinRel.JoinType = {
def toSubstraitType(joinType: JoinType, buildSide: BuildSide): JoinRel.JoinType = {
joinType match {
case _: InnerLike =>
JoinRel.JoinType.JOIN_TYPE_INNER
case FullOuter =>
JoinRel.JoinType.JOIN_TYPE_OUTER
case LeftOuter | RightOuter =>
case LeftOuter =>
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildLeft) =>
// The tables order will be reversed in HashJoinLikeExecTransformer
JoinRel.JoinType.JOIN_TYPE_LEFT
case RightOuter if (buildSide == BuildRight) =>
// This the case rewritten in ReorderJoinLeftRightRule
JoinRel.JoinType.JOIN_TYPE_RIGHT
case LeftSemi | ExistenceJoin(_) =>
JoinRel.JoinType.JOIN_TYPE_LEFT_SEMI
case LeftAnti =>
Expand Down Expand Up @@ -97,7 +103,7 @@ case class CHShuffledHashJoinExecTransformer(
}
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType)
JoinTypeTransform.toSubstraitType(joinType, buildSide)
}

case class CHBroadcastBuildSideRDD(
Expand Down Expand Up @@ -205,5 +211,5 @@ case class CHBroadcastHashJoinExecTransformer(
// and isExistenceJoin is set to true to indicate that it is an existence join.
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType)
JoinTypeTransform.toSubstraitType(joinType, buildSide)
}
Original file line number Diff line number Diff line change
Expand Up @@ -531,5 +531,33 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
spark.sql("drop table t1")
spark.sql("drop table t2")
}

test("GLUTEN-6768 rerorder hash join") {
withSQLConf(("spark.gluten.sql.columnar.backend.ch.enable.reorder_hash_join_tables", "true")) {
spark.sql("create table t1(a int, b int) using parquet")
spark.sql("create table t2(a int, b int) using parquet")

spark.sql("insert into t1 select id as a, id as b from range(10000)")
spark.sql("insert into t1 select id as a, id as b from range(100)")

var sql = """
|select * from t2 left join t1 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t2 right join t1 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t1 right join t2 on t1.a = t2.a
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
}
// scalastyle:off line.size.limit

0 comments on commit b4e5420

Please sign in to comment.