Skip to content

Commit

Permalink
try to use multi join on clause as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 12, 2024
1 parent 2e797f3 commit 81724ce
Show file tree
Hide file tree
Showing 6 changed files with 379 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.execution

import org.apache.gluten.backendsapi.BackendsApiManager
import org.apache.gluten.extension.ValidationResult
import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy}

Expand All @@ -25,10 +26,13 @@ import org.apache.spark.rpc.GlutenDriverEndpoint
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.BuildSide
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

object JoinTypeTransform {
Expand Down Expand Up @@ -60,6 +64,8 @@ object JoinTypeTransform {
}
}

case class ShuffleStageStaticstics(numPartitions: Int, numMappers: Int, rowCount: Option[BigInt])

case class CHShuffledHashJoinExecTransformer(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
Expand Down Expand Up @@ -98,6 +104,73 @@ case class CHShuffledHashJoinExecTransformer(
private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType)
override protected lazy val substraitJoinType: JoinRel.JoinType =
JoinTypeTransform.toSubstraitType(joinType)

override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId) = genJoinParametersInternal()

// Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer`
val leftStats = getShuffleStageStatistics(streamedPlan)
val rightStats = getShuffleStageStatistics(buildPlan)
// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
// isNullAwareAntiJoin: 0 for false, 1 for true
// buildHashTableId: the unique id for the hash table of build plan
joinParametersStr
.append("isBHJ=")
.append(isBHJ)
.append("\n")
.append("isNullAwareAntiJoin=")
.append(isNullAwareAntiJoin)
.append("\n")
.append("buildHashTableId=")
.append(buildHashTableId)
.append("\n")
.append("isExistenceJoin=")
.append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0)
.append("\n")
.append("leftRowCount=")
.append(leftStats.rowCount.getOrElse(-1))
.append("\n")
.append("leftNumPartitions=")
.append(leftStats.numPartitions)
.append("\n")
.append("leftNumMappers=")
.append(leftStats.numMappers)
.append("\n")
.append("rightRowCount=")
.append(rightStats.rowCount.getOrElse(-1))
.append("\n")
.append("rightNumPartitions=")
.append(rightStats.numPartitions)
.append("\n")
.append("rightNumMappers=")
.append(rightStats.numMappers)
.append("\n")
val message = StringValue
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = {
plan match {
case queryStage: ShuffleQueryStageExec =>
ShuffleStageStaticstics(
queryStage.shuffle.numPartitions,
queryStage.shuffle.numMappers,
queryStage.getRuntimeStatistics.rowCount)
case shuffle: ColumnarShuffleExchangeExec =>
ShuffleStageStaticstics(shuffle.numPartitions, shuffle.numMappers, None)
case _ =>
if (plan.children.length == 1) {
getShuffleStageStatistics(plan.children.head)
} else {
ShuffleStageStaticstics(-1, -1, None)
}
}
}
}

case class CHBroadcastBuildSideRDD(
Expand Down
21 changes: 21 additions & 0 deletions cpp-ch/local-engine/Common/GlutenConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ struct StreamingAggregateConfig
}
};

struct JoinConfig
{
/// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi
/// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or t1.id2 = t2.id2)`
inline static const String PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES = "prefer_inequal_join_to_multi_join_on_clauses";
/// Only hash join supports multi join on clauses, the right table cannot be to large. If the row number of right
/// table is larger then this limit, this transform will not work.
inline static const String INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT = "inequal_join_to_multi_join_on_clauses_row_limit";

bool prefer_inequal_join_to_multi_join_on_clauses = true;
size_t inequal_join_to_multi_join_on_clauses_rows_limit = 10000000;

static JoinConfig loadFromContext(DB::ContextPtr context)
{
JoinConfig config;
config.prefer_inequal_join_to_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES, true);
config.inequal_join_to_multi_join_on_clauses_rows_limit = context->getConfigRef().getUInt64(INEQUAL_JOIN_TO_MULTI_JOIN_ON_CLAUSES_ROWS_LIMIT, 10000000);
return config;
}
};

struct ExecutorConfig
{
inline static const String DUMP_PIPELINE = "dump_pipeline";
Expand Down
16 changes: 16 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ void tryAssign<bool>(const std::unordered_map<String, String> & kvs, const Strin
}
}

template<>
void tryAssign<Int64>(const std::unordered_map<String, String> & kvs, const String & key, Int64 & v)
{
auto it = kvs.find(key);
if (it != kvs.end())
{
v = std::stol(it->second);
}
}

template <char... chars>
void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf)
{
Expand Down Expand Up @@ -121,6 +131,12 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance)
tryAssign(kvs, "buildHashTableId", info.storage_join_key);
tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join);
tryAssign(kvs, "isExistenceJoin", info.is_existence_join);
tryAssign(kvs, "leftRowCount", info.left_table_rows);
tryAssign(kvs, "leftNumPartitions", info.left_table_partitions_num);
tryAssign(kvs, "leftNumMappers", info.left_table_mappers_num);
tryAssign(kvs, "rightRowCount", info.right_table_rows);
tryAssign(kvs, "rightNumPartitions", info.right_table_partitions_num);
tryAssign(kvs, "rightNumMappers", info.right_table_mappers_num);
return info;
}
}
Expand Down
6 changes: 6 additions & 0 deletions cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ struct JoinOptimizationInfo
bool is_smj = false;
bool is_null_aware_anti_join = false;
bool is_existence_join = false;
Int64 left_table_rows = -1;
Int64 left_table_partitions_num = -1;
Int64 left_table_mappers_num = -1;
Int64 right_table_rows = -1;
Int64 right_table_partitions_num = -1;
Int64 right_table_mappers_num = -1;
String storage_join_key;

static JoinOptimizationInfo parse(const String & advance);
Expand Down
Loading

0 comments on commit 81724ce

Please sign in to comment.