From 45457d82fe9160316fc6124304d5163a41a6f092 Mon Sep 17 00:00:00 2001 From: lgbo Date: Fri, 16 Aug 2024 11:48:14 +0800 Subject: [PATCH] [GLUTEN-6768][CH] Try to use multi join on clauses instead of inequal join condition (#6787) What changes were proposed in this pull request? (Please fill in changes proposed in this fix) Fixes: #6768 Transform a join with inequal condition into multi join on clauses as possible, it could be more efficient. For example convert on t1.key = t2.key and (t1.a1 = t2.a1 or t1.a2 = t1.a2 or t1.a3 = t2.a3) to on (t1.key = t2.key and t1.a1 = t2.a1) or (t1.key = t2.key and t1.a2 = t1.a2) or (t1.key = t2.key and t1.a3 = t2.a3) We need to limit the right table size to avoid OOM, because we can only use hash join algorithm on multi join on clauses. How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) unit tests (If this patch involves UI changes, please attach a screenshot; otherwise, remove this) --- .../clickhouse/CHSparkPlanExecApi.scala | 3 +- .../execution/CHHashJoinExecTransformer.scala | 58 ++++ .../org/apache/gluten/utils/CHAQEUtil.scala | 39 +++ ...tenClickHouseColumnarShuffleAQESuite.scala | 45 +++ cpp-ch/local-engine/Common/GlutenConfig.h | 21 ++ .../Parser/AdvancedParametersParseUtil.cpp | 23 ++ .../Parser/AdvancedParametersParseUtil.h | 5 + cpp-ch/local-engine/Parser/JoinRelParser.cpp | 302 ++++++++++++++---- cpp-ch/local-engine/Parser/JoinRelParser.h | 19 ++ 9 files changed, 460 insertions(+), 55 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 03e5aaa538a9..5a49d6ea3d66 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -309,7 +309,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { condition: Option[Expression], left: SparkPlan, right: SparkPlan, - isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = + isSkewJoin: Boolean): ShuffledHashJoinExecTransformerBase = { CHShuffledHashJoinExecTransformer( leftKeys, rightKeys, @@ -319,6 +319,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { left, right, isSkewJoin) + } /** Generate BroadcastHashJoinExecTransformer. */ def genBroadcastHashJoinExecTransformer( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index 7080e55dc186..adb824804718 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.gluten.backendsapi.BackendsApiManager import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -29,6 +30,7 @@ import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch +import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel object JoinTypeTransform { @@ -104,6 +106,62 @@ case class CHShuffledHashJoinExecTransformer( private val finalJoinType = JoinTypeTransform.toNativeJoinType(joinType) override protected lazy val substraitJoinType: JoinRel.JoinType = JoinTypeTransform.toSubstraitType(joinType, buildSide) + + override def genJoinParameters(): Any = { + val (isBHJ, isNullAwareAntiJoin, buildHashTableId): (Int, Int, String) = (0, 0, "") + + // Start with "JoinParameters:" + val joinParametersStr = new StringBuffer("JoinParameters:") + // isBHJ: 0 for SHJ, 1 for BHJ + // isNullAwareAntiJoin: 0 for false, 1 for true + // buildHashTableId: the unique id for the hash table of build plan + joinParametersStr + .append("isBHJ=") + .append(isBHJ) + .append("\n") + .append("isNullAwareAntiJoin=") + .append(isNullAwareAntiJoin) + .append("\n") + .append("buildHashTableId=") + .append(buildHashTableId) + .append("\n") + .append("isExistenceJoin=") + .append(if (joinType.isInstanceOf[ExistenceJoin]) 1 else 0) + .append("\n") + + CHAQEUtil.getShuffleQueryStageStats(streamedPlan) match { + case Some(stats) => + joinParametersStr + .append("leftRowCount=") + .append(stats.rowCount.getOrElse(-1)) + .append("\n") + .append("leftSizeInBytes=") + .append(stats.sizeInBytes) + .append("\n") + case _ => + } + CHAQEUtil.getShuffleQueryStageStats(buildPlan) match { + case Some(stats) => + joinParametersStr + .append("rightRowCount=") + .append(stats.rowCount.getOrElse(-1)) + .append("\n") + .append("rightSizeInBytes=") + .append(stats.sizeInBytes) + .append("\n") + case _ => + } + joinParametersStr + .append("numPartitions=") + .append(outputPartitioning.numPartitions) + .append("\n") + + val message = StringValue + .newBuilder() + .setValue(joinParametersStr.toString) + .build() + BackendsApiManager.getTransformerApiInstance.packPBMessage(message) + } } case class CHBroadcastBuildSideRDD( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala new file mode 100644 index 000000000000..9a35517f54fc --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHAQEUtil.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive._ + +object CHAQEUtil { + + // All TransformSupports have lost the logicalLink. So we need iterate the plan to find the + // first ShuffleQueryStageExec and get the runtime stats. + def getShuffleQueryStageStats(plan: SparkPlan): Option[Statistics] = { + plan match { + case stage: ShuffleQueryStageExec => + Some(stage.getRuntimeStatistics) + case _ => + if (plan.children.length == 1) { + getShuffleQueryStageStats(plan.children.head) + } else { + None + } + } + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala index fc22add2d880..10e5c7534d35 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseColumnarShuffleAQESuite.scala @@ -32,6 +32,7 @@ class GlutenClickHouseColumnarShuffleAQESuite override protected val tablesPath: String = basePath + "/tpch-data-ch" override protected val tpchQueries: String = rootPath + "queries/tpch-queries-ch" override protected val queriesResults: String = rootPath + "mergetree-queries-output" + private val backendConfigPrefix = "spark.gluten.sql.columnar.backend.ch." /** Run Gluten + ClickHouse Backend with ColumnarShuffleManager */ override protected def sparkConf: SparkConf = { @@ -261,4 +262,48 @@ class GlutenClickHouseColumnarShuffleAQESuite spark.sql("drop table t2") } } + + test("GLUTEN-6768 change mixed join condition into multi join on clauses") { + withSQLConf( + (backendConfigPrefix + "runtime_config.prefer_multi_join_on_clauses", "true"), + (backendConfigPrefix + "runtime_config.multi_join_on_clauses_build_side_row_limit", "1000000") + ) { + + spark.sql("create table t1(a int, b int, c int, d int) using parquet") + spark.sql("create table t2(a int, b int, c int, d int) using parquet") + + spark.sql(""" + |insert into t1 + |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + spark.sql(""" + |insert into t2 + |select id % 2 as a, id as b, id + 1 as c, id + 2 as d from range(1000) + |""".stripMargin) + + var sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or t1.d = t2.d) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.c = t2.c and t1.d = t2.d)) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + sql = """ + |select * from t1 join t2 on + |t1.a = t2.a and (t1.b = t2.b or t1.c = t2.c or (t1.d = t2.d and t1.c >= t2.c)) + |order by t1.a, t1.b, t1.c, t1.d + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + + spark.sql("drop table t1") + spark.sql("drop table t2") + } + } } diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index abb7295adc0d..84744dab21b8 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -92,6 +92,27 @@ struct StreamingAggregateConfig } }; +struct JoinConfig +{ + /// If the join condition is like `t1.k = t2.k and (t1.id1 = t2.id2 or t1.id2 = t2.id2)`, try to join with multi + /// join on clauses `(t1.k = t2.k and t1.id1 = t2.id2) or (t1.k = t2.k or t1.id2 = t2.id2)` + inline static const String PREFER_MULTI_JOIN_ON_CLAUSES = "prefer_multi_join_on_clauses"; + /// Only hash join supports multi join on clauses, the right table cannot be too large. If the row number of right + /// table is larger then this limit, this transform will not work. + inline static const String MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT = "multi_join_on_clauses_build_side_row_limit"; + + bool prefer_multi_join_on_clauses = true; + size_t multi_join_on_clauses_build_side_rows_limit = 10000000; + + static JoinConfig loadFromContext(DB::ContextPtr context) + { + JoinConfig config; + config.prefer_multi_join_on_clauses = context->getConfigRef().getBool(PREFER_MULTI_JOIN_ON_CLAUSES, true); + config.multi_join_on_clauses_build_side_rows_limit = context->getConfigRef().getUInt64(MULTI_JOIN_ON_CLAUSES_BUILD_SIDE_ROWS_LIMIT, 10000000); + return config; + } +}; + struct ExecutorConfig { inline static const String DUMP_PIPELINE = "dump_pipeline"; diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp index a7a07c0bf31a..42d4f4d4d8cd 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.cpp @@ -57,6 +57,24 @@ void tryAssign(const std::unordered_map & kvs, const Strin } } +template<> +void tryAssign(const std::unordered_map & kvs, const String & key, Int64 & v) +{ + auto it = kvs.find(key); + if (it != kvs.end()) + { + try + { + v = std::stol(it->second); + } + catch (...) + { + LOG_ERROR(getLogger("tryAssign"), "Invalid number: {}", it->second); + throw; + } + } +} + template void readStringUntilCharsInto(String & s, DB::ReadBuffer & buf) { @@ -121,6 +139,11 @@ JoinOptimizationInfo JoinOptimizationInfo::parse(const String & advance) tryAssign(kvs, "buildHashTableId", info.storage_join_key); tryAssign(kvs, "isNullAwareAntiJoin", info.is_null_aware_anti_join); tryAssign(kvs, "isExistenceJoin", info.is_existence_join); + tryAssign(kvs, "leftRowCount", info.left_table_rows); + tryAssign(kvs, "leftSizeInBytes", info.left_table_bytes); + tryAssign(kvs, "rightRowCount", info.right_table_rows); + tryAssign(kvs, "rightSizeInBytes", info.right_table_bytes); + tryAssign(kvs, "numPartitions", info.partitions_num); return info; } } diff --git a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h index 5a15a3ea8abc..5f6fe6d256e3 100644 --- a/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h +++ b/cpp-ch/local-engine/Parser/AdvancedParametersParseUtil.h @@ -29,6 +29,11 @@ struct JoinOptimizationInfo bool is_smj = false; bool is_null_aware_anti_join = false; bool is_existence_join = false; + Int64 left_table_rows = -1; + Int64 left_table_bytes = -1; + Int64 right_table_rows = -1; + Int64 right_table_bytes = -1; + Int64 partitions_num = -1; String storage_join_key; static JoinOptimizationInfo parse(const String & advance); diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index b217a9bd9da2..ef19e007d439 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -16,6 +16,8 @@ */ #include "JoinRelParser.h" +#include +#include #include #include #include @@ -25,15 +27,15 @@ #include #include #include -#include #include +#include #include #include #include #include #include #include -#include +#include #include @@ -42,9 +44,9 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; } } using namespace DB; @@ -98,7 +100,8 @@ DB::QueryPlanPtr JoinRelParser::parseOp(const substrait::Rel & rel, std::list JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) +std::unordered_set JoinRelParser::extractTableSidesFromExpression( + const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) { std::unordered_set table_sides; if (expr.has_scalar_function()) @@ -169,8 +172,7 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ storage_join.getRightSampleBlock().getColumnsWithTypeAndName(), ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr right_project_step = - std::make_unique(right.getCurrentDataStream(), std::move(right_project)); + QueryPlanStepPtr right_project_step = std::make_unique(right.getCurrentDataStream(), std::move(right_project)); right_project_step->setStepDescription("Rename Broadcast Table Name"); steps.emplace_back(right_project_step.get()); right.addStep(std::move(right_project_step)); @@ -193,12 +195,9 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ } } ActionsDAG left_project = ActionsDAG::makeConvertingActions( - left.getCurrentDataStream().header.getColumnsWithTypeAndName(), - new_left_cols, - ActionsDAG::MatchColumnsMode::Position); + left.getCurrentDataStream().header.getColumnsWithTypeAndName(), new_left_cols, ActionsDAG::MatchColumnsMode::Position); - QueryPlanStepPtr left_project_step = - std::make_unique(left.getCurrentDataStream(), std::move(left_project)); + QueryPlanStepPtr left_project_step = std::make_unique(left.getCurrentDataStream(), std::move(left_project)); left_project_step->setStepDescription("Rename Left Table Name for broadcast join"); steps.emplace_back(left_project_step.get()); left.addStep(std::move(left_project_step)); @@ -206,9 +205,11 @@ void JoinRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & righ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) { + auto join_config = JoinConfig::loadFromContext(getContext()); google::protobuf::StringValue optimization_info; optimization_info.ParseFromString(join.advanced_extension().optimization().value()); auto join_opt_info = JoinOptimizationInfo::parse(optimization_info.value()); + LOG_ERROR(getLogger("JoinRelParser"), "optimizaiton info:{}", optimization_info.value()); auto storage_join = join_opt_info.is_broadcast ? BroadCastJoinBuilder::getJoin(join_opt_info.storage_join_key) : nullptr; if (storage_join) { @@ -239,7 +240,9 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } if (is_col_names_changed) { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, + "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", left->getCurrentDataStream().header.dumpStructure(), right_header_before_convert_step.dumpStructure(), right->getCurrentDataStream().header.dumpStructure()); @@ -266,7 +269,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q if (storage_join) { - applyJoinFilter(*table_join, join, *left, *right, true); auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); @@ -288,15 +290,13 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q /// TODO: make smj support mixed conditions if (need_post_filter && table_join->kind() != DB::JoinKind::Inner) { - throw DB::Exception( - DB::ErrorCodes::LOGICAL_ERROR, - "Sort merge join doesn't support mixed join conditions, except inner join."); + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Sort merge join doesn't support mixed join conditions, except inner join."); } JoinPtr smj_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty(), -1); MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); + = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), smj_join, 8192, 1, false); join_step->setStepDescription("SORT_MERGE_JOIN"); steps.emplace_back(join_step.get()); @@ -311,41 +311,22 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } else { - applyJoinFilter(*table_join, join, *left, *right, true); - - /// Following is some configurations for grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will - /// enable grace hash join. - /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup - /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, - /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number - /// will be too large and the performance will be bad. - JoinPtr hash_join = nullptr; - MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; - if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + std::vector join_on_clauses; + if (table_join->getClauses().empty()) + table_join->addDisjunct(); + bool is_multi_join_on_clauses + = couldRewriteToMultiJoinOnClauses(table_join->getOnlyClause(), join_on_clauses, join, left_header, right_header); + if (is_multi_join_on_clauses && join_config.prefer_multi_join_on_clauses && join_opt_info.right_table_rows > 0 + && join_opt_info.partitions_num > 0 + && join_opt_info.right_table_rows / join_opt_info.partitions_num + < join_config.multi_join_on_clauses_build_side_rows_limit) { - hash_join = std::make_shared( - context, - table_join, - left->getCurrentDataStream().header, - right->getCurrentDataStream().header, - context->getTempDataOnDisk()); + query_plan = buildMultiOnClauseHashJoin(table_join, std::move(left), std::move(right), join_on_clauses); } else { - hash_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty()); + query_plan = buildSingleOnClauseHashJoin(join, table_join, std::move(left), std::move(right)); } - QueryPlanStepPtr join_step - = std::make_unique(left->getCurrentDataStream(), right->getCurrentDataStream(), hash_join, 8192, 1, false); - - join_step->setStepDescription("HASH_JOIN"); - steps.emplace_back(join_step.get()); - std::vector plans; - plans.emplace_back(std::move(left)); - plans.emplace_back(std::move(right)); - - query_plan = std::make_unique(); - query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } JoinUtil::reorderJoinOutput(*query_plan, after_join_names); @@ -508,7 +489,11 @@ void JoinRelParser::collectJoinKeys( } bool JoinRelParser::applyJoinFilter( - DB::TableJoin & table_join, const substrait::JoinRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) + DB::TableJoin & table_join, + const substrait::JoinRel & join_rel, + DB::QueryPlan & left, + DB::QueryPlan & right, + bool allow_mixed_condition) { if (!join_rel.has_post_join_filter()) return true; @@ -594,12 +579,13 @@ bool JoinRelParser::applyJoinFilter( return false; auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); mixed_join_expressions_actions.removeUnusedActions(); - table_join.getMixedJoinExpression() - = std::make_shared(std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); + table_join.getMixedJoinExpression() = std::make_shared( + std::move(mixed_join_expressions_actions), ExpressionActionsSettings::fromContext(context)); } else { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); + throw DB::Exception( + DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); } return true; } @@ -610,7 +596,7 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J ActionsDAG actions_dag{query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()}; if (!join.post_join_filter().has_scalar_function()) { - // It may be singular_or_list + // It may be singular_or_list auto * in_node = getPlanParser()->parseExpression(actions_dag, join.post_join_filter()); filter_name = in_node->result_name; } @@ -624,6 +610,214 @@ void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::J query_plan.addStep(std::move(filter_step)); } +/// Only support following pattern: a1 = b1 or a2 = b2 or (a3 = b3 and a4 = b4) +bool JoinRelParser::couldRewriteToMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector & clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header) +{ + /// There is only one join clause + if (!join_rel.has_post_join_filter()) + return false; + + const auto & filter_expr = join_rel.post_join_filter(); + std::list expression_stack; + expression_stack.push_back(&filter_expr); + + auto check_function = [&](const String function_name_, const substrait::Expression & e) + { + if (!e.has_scalar_function()) + { + return false; + } + auto function_name = parseFunctionName(e.scalar_function().function_reference(), e.scalar_function()); + return function_name.has_value() && *function_name == function_name_; + }; + + auto get_field_ref = [](const substrait::Expression & e) -> std::optional + { + if (e.has_selection() && e.selection().has_direct_reference() && e.selection().direct_reference().has_struct_field()) + { + return std::optional(e.selection().direct_reference().struct_field().field()); + } + return {}; + }; + + auto parse_join_keys = [&](const substrait::Expression & e) -> std::optional> + { + const auto & args = e.scalar_function().arguments(); + auto l_field_ref = get_field_ref(args[0].value()); + auto r_field_ref = get_field_ref(args[1].value()); + if (!l_field_ref.has_value() || !r_field_ref.has_value()) + return {}; + size_t l_pos = static_cast(*l_field_ref); + size_t r_pos = static_cast(*r_field_ref); + size_t l_cols = left_header.columns(); + size_t total_cols = l_cols + right_header.columns(); + + if (l_pos < l_cols && r_pos >= l_cols && r_pos < total_cols) + return std::make_pair(left_header.getByPosition(l_pos).name, right_header.getByPosition(r_pos - l_cols).name); + else if (r_pos < l_cols && l_pos >= l_cols && l_pos < total_cols) + return std::make_pair(left_header.getByPosition(r_pos).name, right_header.getByPosition(l_pos - l_cols).name); + return {}; + }; + + auto parse_and_expression = [&](const substrait::Expression & e, DB::TableJoin::JoinOnClause & join_on_clause) + { + std::vector and_expression_stack; + and_expression_stack.push_back(&e); + while (!and_expression_stack.empty()) + { + const auto & current_expr = *(and_expression_stack.back()); + and_expression_stack.pop_back(); + if (check_function("and", current_expr)) + { + for (const auto & arg : e.scalar_function().arguments()) + and_expression_stack.push_back(&arg.value()); + } + else if (check_function("equals", current_expr)) + { + auto optional_keys = parse_join_keys(current_expr); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + join_on_clause.addKey(optional_keys->first, optional_keys->second, false); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "And or equals function is expected"); + return false; + } + } + return true; + }; + + while (!expression_stack.empty()) + { + const auto & current_expr = *(expression_stack.back()); + expression_stack.pop_back(); + if (!check_function("or", current_expr)) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not an or expression"); + } + + auto get_current_join_on_clause = [&]() + { + DB::TableJoin::JoinOnClause new_clause = prefix_clause; + clauses.push_back(new_clause); + return &clauses.back(); + }; + + const auto & args = current_expr.scalar_function().arguments(); + for (const auto & arg : args) + { + if (check_function("equals", arg.value())) + { + auto optional_keys = parse_join_keys(arg.value()); + if (!optional_keys) + { + LOG_ERROR(getLogger("JoinRelParser"), "Not equal comparison for keys from both tables"); + return false; + } + get_current_join_on_clause()->addKey(optional_keys->first, optional_keys->second, false); + } + else if (check_function("and", arg.value())) + { + if (!parse_and_expression(arg.value(), *get_current_join_on_clause())) + { + LOG_ERROR(getLogger("JoinRelParser"), "Parse and expression failed"); + return false; + } + } + else if (check_function("or", arg.value())) + { + expression_stack.push_back(&arg.value()); + } + else + { + LOG_ERROR(getLogger("JoinRelParser"), "Unknow function"); + return false; + } + } + } + return true; +} + + +DB::QueryPlanPtr JoinRelParser::buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses) +{ + DB::TableJoin::JoinOnClause & base_join_on_clause = table_join->getOnlyClause(); + base_join_on_clause = join_on_clauses[0]; + for (size_t i = 1; i < join_on_clauses.size(); ++i) + { + table_join->addDisjunct(); + auto & join_on_clause = table_join->getClauses().back(); + join_on_clause = join_on_clauses[i]; + } + + LOG_INFO(getLogger("JoinRelParser"), "multi join on clauses:\n{}", DB::TableJoin::formatClauses(table_join->getClauses())); + + JoinPtr hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header); + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + join_step->setStepDescription("Multi join on clause hash join"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + return query_plan; +} + +DB::QueryPlanPtr JoinRelParser::buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, std::shared_ptr table_join, DB::QueryPlanPtr left_plan, DB::QueryPlanPtr right_plan) +{ + applyJoinFilter(*table_join, join_rel, *left_plan, *right_plan, true); + /// Following is some configurations for grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will + /// enable grace hash join. + /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. This setup + /// the memory limitation fro grace hash join. If the memory consumption exceeds the limitation, + /// data will be spilled to disk. Don't set the limitation too small, otherwise the buckets number + /// will be too large and the performance will be bad. + JoinPtr hash_join = nullptr; + MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; + if (join_algorithm.isSet(DB::JoinAlgorithm::GRACE_HASH)) + { + hash_join = std::make_shared( + context, + table_join, + left_plan->getCurrentDataStream().header, + right_plan->getCurrentDataStream().header, + context->getTempDataOnDisk()); + } + else + { + hash_join = std::make_shared(table_join, right_plan->getCurrentDataStream().header.cloneEmpty()); + } + QueryPlanStepPtr join_step + = std::make_unique(left_plan->getCurrentDataStream(), right_plan->getCurrentDataStream(), hash_join, 8192, 1, false); + + join_step->setStepDescription("HASH_JOIN"); + steps.emplace_back(join_step.get()); + std::vector plans; + plans.emplace_back(std::move(left_plan)); + plans.emplace_back(std::move(right_plan)); + + auto query_plan = std::make_unique(); + query_plan->unitePlans(std::move(join_step), {std::move(plans)}); + return query_plan; +} + void registerJoinRelParser(RelParserFactory & factory) { auto builder = [](SerializedPlanParser * plan_paser) { return std::make_shared(plan_paser); }; diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index ee1155cb4712..7e43187be308 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -70,6 +71,24 @@ class JoinRelParser : public RelParser static std::unordered_set extractTableSidesFromExpression( const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); + + bool couldRewriteToMultiJoinOnClauses( + const DB::TableJoin::JoinOnClause & prefix_clause, + std::vector & clauses, + const substrait::JoinRel & join_rel, + const DB::Block & left_header, + const DB::Block & right_header); + + DB::QueryPlanPtr buildMultiOnClauseHashJoin( + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan, + const std::vector & join_on_clauses); + DB::QueryPlanPtr buildSingleOnClauseHashJoin( + const substrait::JoinRel & join_rel, + std::shared_ptr table_join, + DB::QueryPlanPtr left_plan, + DB::QueryPlanPtr right_plan); }; }