From 647c5041935a7e7b2064de925e96bff542e5933c Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Thu, 15 Aug 2024 14:41:29 +0800 Subject: [PATCH] update --- .../execution/CHHashJoinExecTransformer.scala | 39 +++---------------- ...tenClickHouseColumnarShuffleAQESuite.scala | 6 +-- cpp-ch/local-engine/Parser/JoinRelParser.cpp | 8 ++-- cpp-ch/local-engine/Parser/JoinRelParser.h | 2 +- 4 files changed, 13 insertions(+), 42 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 252b9bc03fd68..990a0163ee826 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -27,9 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.adaptive._ import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch @@ -115,9 +113,6 @@ case class CHShuffledHashJoinExecTransformer( override def genJoinParameters(): Any = { val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") - // Don't use lef/right directly, they may be reordered in `HashJoinLikeExecTransformer` - val leftStats = getShuffleStageStatistics(streamedPlan) - val rightStats = getShuffleStageStatistics(buildPlan) // Start with "JoinParameters:" val joinParametersStr = new StringBuffer("JoinParameters:") // isBHJ: 0 for SHJ, 1 for BHJ @@ -138,14 +133,12 @@ case class CHShuffledHashJoinExecTransformer( .append("\n") logicalLink match { case Some(join: Join) => - val leftRowCount = - if (needSwitchChildren) join.left.stats.rowCount else join.right.stats.rowCount - val rightRowCount = - if (needSwitchChildren) join.right.stats.rowCount else join.left.stats.rowCount - val leftSizeInBytes = - if (needSwitchChildren) join.left.stats.sizeInBytes else join.right.stats.sizeInBytes - val rightSizeInBytes = - if (needSwitchChildren) join.right.stats.sizeInBytes else join.left.stats.sizeInBytes + val left = if (!needSwitchChildren) join.left else join.right + val right = if (!needSwitchChildren) join.right else join.left + val leftRowCount = left.stats.rowCount + val rightRowCount = right.stats.rowCount + val leftSizeInBytes = left.stats.sizeInBytes + val rightSizeInBytes = right.stats.sizeInBytes val numPartitions = outputPartitioning.numPartitions joinParametersStr .append("leftRowCount=") @@ -171,26 +164,6 @@ case class CHShuffledHashJoinExecTransformer( .build() BackendsApiManager.getTransformerApiInstance.packPBMessage(message) } - - private def getShuffleStageStatistics(plan: SparkPlan): ShuffleStageStaticstics = { - plan match { - case queryStage: ShuffleQueryStageExec => - ShuffleStageStaticstics( - queryStage.shuffle.numPartitions, - queryStage.shuffle.numMappers, - queryStage.getRuntimeStatistics.rowCount) - case shuffle: ColumnarShuffleExchangeExec => - // FIXEME: We cannot access shuffle.numPartitions and shuffle.numMappers here. - // Otherwise it will cause an exception `ProjectExecTransformer has column support mismatch` - ShuffleStageStaticstics(-1, -1, None) - case _ => - if (plan.children.length == 1) { - getShuffleStageStatistics(plan.children.head) - } else { - ShuffleStageStaticstics(-1, -1, None) - } - } - } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index ebeb69c16350a..10e5c7534d352 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -265,10 +265,8 @@ class GlutenClickHouseColumnarShuffleAQESuite test("GLUTEN-6768 change mixed join condition into multi join on clauses") { withSQLConf( - (backendConfigPrefix + "runtime_config.prefer_inequal_join_to_multi_join_on_clauses", "true"), - ( - backendConfigPrefix + "runtime_config.inequal_join_to_multi_join_on_clauses_row_limit", - "1000000") + (backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"), + (backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000") ) { spark.sql("create table t1(a int, b int, c int, d int) using parquet") diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index fbab2609412ab..3434cc3b0b94f 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -209,7 +209,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q google::protobuf::StringValue optimization_info; optimization_info.ParseFromString(join.advanced_extension().optimization().value()); auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value()); - LOG_ERROR(getLogger("JoinRelParser"), "{}", optimization_info.value()); + LOG_DEBUG(getLogger("JoinRelParser"), "optimizaiton info:{}", optimization_info.value()); auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr; if (storage_join) { @@ -315,7 +315,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (table_join->getClauses().empty()) table_join->addDisjunct(); bool is_multi_join_on_clauses - = isJoinWithMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0 && join_opt_info.partitions_num > 0 && join_opt_info.right_table_rows / join_opt_info.partitions_num @@ -611,14 +611,14 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J } /// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4) -bool JoinRelParser::isJoinWithMultiJoinOnClauses( +bool JoinRelParser::couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) { - /// There is only on join clause + /// There is only one join clause if (!join_rel.has_post_join_filter()) return false; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index 0c0d07d6fdd27..7e43187be308b 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -72,7 +72,7 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); - bool isJoinWithMultiJoinOnClauses( + bool couldRewriteToMultiJoinOnClauses( const DB::TableJoin::JoinOnClause & prefix_clause, std::vector & clauses, const substrait::JoinRel & join_rel,