Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
lgbo-ustc committed Aug 15, 2024
1 parent 8081b62 commit 647c504
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.adaptive._
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.vectorized.ColumnarBatch

Expand Down Expand Up @@ -115,9 +113,6 @@ case class CHShuffledHashJoinExecTransformer(
override def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "")

// Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer`
val leftStats = getShuffleStageStatistics(streamedPlan)
val rightStats = getShuffleStageStatistics(buildPlan)
// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
Expand All @@ -138,14 +133,12 @@ case class CHShuffledHashJoinExecTransformer(
.append("\n")
logicalLink match {
case Some(join: Join) =>
val leftRowCount =
if (needSwitchChildren) join.left.stats.rowCount else join.right.stats.rowCount
val rightRowCount =
if (needSwitchChildren) join.right.stats.rowCount else join.left.stats.rowCount
val leftSizeInBytes =
if (needSwitchChildren) join.left.stats.sizeInBytes else join.right.stats.sizeInBytes
val rightSizeInBytes =
if (needSwitchChildren) join.right.stats.sizeInBytes else join.left.stats.sizeInBytes
val left = if (!needSwitchChildren) join.left else join.right
val right = if (!needSwitchChildren) join.right else join.left
val leftRowCount = left.stats.rowCount
val rightRowCount = right.stats.rowCount
val leftSizeInBytes = left.stats.sizeInBytes
val rightSizeInBytes = right.stats.sizeInBytes
val numPartitions = outputPartitioning.numPartitions
joinParametersStr
.append("leftRowCount=")
Expand All @@ -171,26 +164,6 @@ case class CHShuffledHashJoinExecTransformer(
.build()
BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = {
plan match {
case queryStage: ShuffleQueryStageExec =>
ShuffleStageStaticstics(
queryStage.shuffle.numPartitions,
queryStage.shuffle.numMappers,
queryStage.getRuntimeStatistics.rowCount)
case shuffle: ColumnarShuffleExchangeExec =>
// FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here.
// Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch`
ShuffleStageStaticstics(-1, -1, None)
case _ =>
if (plan.children.length == 1) {
getShuffleStageStatistics(plan.children.head)
} else {
ShuffleStageStaticstics(-1, -1, None)
}
}
}
}

case class CHBroadcastBuildSideRDD(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -265,10 +265,8 @@ class GlutenClickHouseColumnarShuffleAQESuite

test("GLUTEN-6768 change mixed join condition into multi join on clauses") {
withSQLConf(
(backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"),
(
backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit",
"1000000")
(backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"),
(backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000")
) {

spark.sql("create table t1(a int, b int, c int, d int) using parquet")
Expand Down
8 changes: 4 additions & 4 deletions cpp-ch/local-engine/Parser/JoinRelParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
google::protobuf::StringValue optimization_info;
optimization_info.ParseFromString(join.advanced_extension().optimization().value());
auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value());
LOG_ERROR(getLogger("JoinRelParser"), "{}", optimization_info.value());
LOG_DEBUG(getLogger("JoinRelParser"), "optimizaiton info:{}", optimization_info.value());
auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr;
if (storage_join)
{
Expand Down Expand Up @@ -315,7 +315,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
if (table_join->getClauses().empty())
table_join->addDisjunct();
bool is_multi_join_on_clauses
= isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header);
= couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header);
if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0
&& join_opt_info.partitions_num > 0
&& join_opt_info.right_table_rows / join_opt_info.partitions_num
Expand Down Expand Up @@ -611,14 +611,14 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J
}

/// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4)
bool JoinRelParser::isJoinWithMultiJoinOnClauses(
bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
const DB::TableJoin::JoinOnClause & prefix_clause,
std::vector<DB::TableJoin::JoinOnClause> & clauses,
const substrait::JoinRel & join_rel,
const DB::Block & left_header,
const DB::Block & right_header)
{
/// There is only on join clause
/// There is only one join clause
if (!join_rel.has_post_join_filter())
return false;

Expand Down
2 changes: 1 addition & 1 deletion cpp-ch/local-engine/Parser/JoinRelParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class JoinRelParser : public RelParser
static std::unordered_set<DB::JoinTableSide> extractTableSidesFromExpression(
const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header);

bool isJoinWithMultiJoinOnClauses(
bool couldRewriteToMultiJoinOnClauses(
const DB::TableJoin::JoinOnClause & prefix_clause,
std::vector<DB::TableJoin::JoinOnClause> & clauses,
const substrait::JoinRel & join_rel,
Expand Down

0 comments on commit 647c504

Please sign in to comment.