Skip to content

Commit

Permalink
[GLUTEN-6768][CH] Try to use multi join on clauses instead of inequal…
Browse files Browse the repository at this point in the history
… join condition (apache#6787)

What changes were proposed in this pull request?
(Please fill in changes proposed in this fix)

Fixes: apache#6768

Transform a join with inequal condition into multi join on clauses as possible, it could be more efficient. For example convert

on t1.key = t2.key and (t1.a1 = t2.a1 or t1.a2 = t1.a2 or t1.a3 = t2.a3)
to

on (t1.key = t2.key and t1.a1 = t2.a1) or (t1.key = t2.key and t1.a2 = t1.a2) or (t1.key = t2.key and t1.a3 = t2.a3)
We need to limit the right table size to avoid OOM, because we can only use hash join algorithm on multi join on clauses.

How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)

unit tests

(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)
  • Loading branch information
lgbo-ustc authored and shamirchen committed Oct 14, 2024
1 parent 0da79e7 commit 45457d8
Show file tree
Hide file tree
Showing 9 changed files with 460 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
condition: Option[Expression],
left: SparkPlan,
right: SparkPlan,
isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase =
isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = {
CHShuffledHashJoinExecTransformer(
leftKeys,
rightKeys,
Expand All @@ -319,6 +319,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
left,
right,
isSkewJoin)
}

/** Generate BroadcastHashJoinExecTransformer. */
def genBroadcastHashJoinExecTransformer(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

Expand All @@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

object JoinTypeTransform {
Expand Down Expand Up @@ -104,6 +106,62 @@ case class CHShuffledHashJoinExecTransformer(
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType, buildSide)

override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "")

// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
// isNullAwareAntiJoin: 0 for false, 1 for true
// buildHashTableId: the unique id for the hash table of build plan
joinParametersStr
.append("isBHJ=")
.append(isBHJ)
.append("\n")
.append("isNullAwareAntiJoin=")
.append(isNullAwareAntiJoin)
.append("\n")
.append("buildHashTableId=")
.append(buildHashTableId)
.append("\n")
.append("isExistenceJoin=")
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")

CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match {
case Some(stats) =>
joinParametersStr
.append("leftRowCount=")
.append(stats.rowCount.getOrElse(-1))
.append("\n")
.append("leftSizeInBytes=")
.append(stats.sizeInBytes)
.append("\n")
case _ =>
}
CHAQEUtil.getShuffleQueryStageStats(buildPlan) match {
case Some(stats) =>
joinParametersStr
.append("rightRowCount=")
.append(stats.rowCount.getOrElse(-1))
.append("\n")
.append("rightSizeInBytes=")
.append(stats.sizeInBytes)
.append("\n")
case _ =>
}
joinParametersStr
.append("numPartitions=")
.append(outputPartitioning.numPartitions)
.append("\n")

val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}
}

case class CHBroadcastBuildSideRDD(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive._

object CHAQEUtil {

// All TransformSupports have lost the logicalLink. So we need iterate the plan to find the
// first ShuffleQueryStageExec and get the runtime stats.
def getShuffleQueryStageStats(plan: SparkPlan): Option[Statistics] = {
plan match {
case stage: ShuffleQueryStageExec =>
Some(stage.getRuntimeStatistics)
case _ =>
if (plan.children.length == 1) {
getShuffleQueryStageStats(plan.children.head)
} else {
None
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite
override protected val tablesPath: String = basePath + "/tpch-data-ch"
override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch"
override protected val queriesResults: String = rootPath + "mergetree-queries-output"
private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch."

/** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */
override protected def sparkConf: SparkConf = {
Expand Down Expand Up @@ -261,4 +262,48 @@ class GlutenClickHouseColumnarShuffleAQESuite
spark.sql("drop table t2")
}
}

test("GLUTEN-6768 change mixed join condition into multi join on clauses") {
withSQLConf(
(backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"),
(backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")
) {

spark.sql("create table t1(a int, b int, c int, d int) using parquet")
spark.sql("create table t2(a int, b int, c int, d int) using parquet")

spark.sql("""
|insert into t1
|select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)
spark.sql("""
|insert into t2
|select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000)
|""".stripMargin)

var sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d)
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.c = t2.c and t1.d = t2.d))
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

sql = """
|select * from t1 join t2 on
|t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.d = t2.d and t1.c >= t2.c))
|order by t1.a, t1.b, t1.c, t1.d
|""".stripMargin
compareResultsAgainstVanillaSpark(sql, true, { _ => })

spark.sql("drop table t1")
spark.sql("drop table t2")
}
}
}
21 changes: 21 additions & 0 deletions cpp-ch/local-engine/Common/GlutenConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ struct StreamingAggregateConfig
}
};

struct JoinConfig
{
/// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi
/// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or t1.id2 = t2.id2)`
inline static const String PREFER_MULTI_JOIN_ON_CLAUSES = "prefer_multi_join_on_clauses";
/// Only hash join supports multi join on clauses, the right table cannot be too large. If the row number of right
/// table is larger then this limit, this transform will not work.
inline static const String MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT = "multi_join_on_clauses_build_side_row_limit";

bool prefer_multi_join_on_clauses = true;
size_t multi_join_on_clauses_build_side_rows_limit = 10000000;

static JoinConfig loadFromContext(DB::ContextPtr context)
{
JoinConfig config;
config.prefer_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_MULTI_JOIN_ON_CLAUSES, true);
config.multi_join_on_clauses_build_side_rows_limit = context->getConfigRef().getUInt64(MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT, 10000000);
return config;
}
};

struct ExecutorConfig
{
inline static const String DUMP_PIPELINE = "dump_pipeline";
Expand Down
23 changes: 23 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,24 @@ void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const Strin
}
}

template<>
void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
{
auto it = kvs.find(key);
if (it != kvs.end())
{
try
{
v = std::stol(it->second);
}
catch (...)
{
LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", it->second);
throw;
}
}
}

template <char... chars>
void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
{
Expand Down Expand Up @@ -121,6 +139,11 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
tryAssign(kvs, "buildHashTableId", info.storage_join_key);
tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
tryAssign(kvs, "leftRowCount", info.left_table_rows);
tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes);
tryAssign(kvs, "rightRowCount", info.right_table_rows);
tryAssign(kvs, "rightSizeInBytes", info.right_table_bytes);
tryAssign(kvs, "numPartitions", info.partitions_num);
return info;
}
}
Expand Down
5 changes: 5 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ struct JoinOptimizationInfo
bool is_smj = false;
bool is_null_aware_anti_join = false;
bool is_existence_join = false;
Int64 left_table_rows = -1;
Int64 left_table_bytes = -1;
Int64 right_table_rows = -1;
Int64 right_table_bytes = -1;
Int64 partitions_num = -1;
String storage_join_key;

static JoinOptimizationInfo parse(const String & advance);
Expand Down
Loading

0 comments on commit 45457d8

Please sign in to comment.