From 6d96819ef345af5677c5a47f17732916a222b76f Mon Sep 17 00:00:00 2001 From: lgbo-ustc Date: Fri, 10 May 2024 09:10:55 +0800 Subject: [PATCH] support inequal join --- .../gluten/utils/CHJoinValidateUtil.scala | 49 +--- ...enClickHouseTPCHSaltNullParquetSuite.scala | 18 ++ cpp-ch/local-engine/Parser/JoinRelParser.cpp | 268 ++++++++++++------ cpp-ch/local-engine/Parser/JoinRelParser.h | 20 +- .../Parser/SerializedPlanParser.cpp | 2 +- 5 files changed, 208 insertions(+), 149 deletions(-) diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala index 06b2445af6e1..40d682dd839f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHJoinValidateUtil.scala @@ -17,7 +17,7 @@ package org.apache.gluten.utils import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.expressions.{AttributeSet, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, In, LessThan, LessThanOrEqual, Not, Or} +import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression} import org.apache.spark.sql.catalyst.plans.JoinType /** @@ -61,53 +61,6 @@ object CHJoinValidateUtil extends Logging { return true } } - if (condition.isDefined) { - condition.get.transform { - case Or(l, r) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - Or(l, r) - case Not(EqualTo(l, r)) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - Not(EqualTo(l, r)) - case LessThan(l, r) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - LessThan(l, r) - case LessThanOrEqual(l, r) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - LessThanOrEqual(l, r) - case GreaterThan(l, r) => - if 
(hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - GreaterThan(l, r) - case GreaterThanOrEqual(l, r) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - GreaterThanOrEqual(l, r) - case In(l, r) => - r.foreach( - e => { - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, e)) { - shouldFallback = true - } - }) - In(l, r) - case EqualTo(l, r) => - if (hasTwoTableColumn(leftOutputSet, rightOutputSet, l, r)) { - shouldFallback = true - } - EqualTo(l, r) - } - } shouldFallback } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index a1bba300ed22..852a1907983f 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2550,5 +2550,23 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) spark.sql("drop table test_tbl_5096") } + + test("Inequal join support") { + withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) { + spark.sql("create table ineq_join_t1 (key bigint, value bigint) using parquet"); + spark.sql("create table ineq_join_t2 (key bigint, value bigint) using parquet"); + spark.sql("insert into ineq_join_t1 values(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)"); + spark.sql("insert into ineq_join_t2 values(2, 2), (2, 1), (3, 3), (4, 6), (5, 3)"); + val sql = + """ + | select t1.key, t1.value, t2.key, t2.value from ineq_join_t1 as t1 + | left join ineq_join_t2 as t2 + | on t1.key = t2.key and t1.value > t2.value + |""".stripMargin + compareResultsAgainstVanillaSpark(sql, true, { _ => }) + spark.sql("drop table 
ineq_join_t1") + spark.sql("drop table ineq_join_t2") + } + } } // scalastyle:on line.size.limit diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 023a51552f82..f536e1286a67 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -31,6 +31,7 @@ #include #include + namespace DB { namespace ErrorCodes @@ -206,11 +207,26 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q after_join_names.insert(after_join_names.end(), left_names.begin(), left_names.end()); auto right_name = table_join->columnsFromJoinedTable().getNames(); after_join_names.insert(after_join_names.end(), right_name.begin(), right_name.end()); - bool add_filter_step = tryAddPushDownFilter(*table_join, join, *left, *right, table_join->columnsFromJoinedTable(), after_join_names); + + const auto & left_header = left->getCurrentDataStream().header; + const auto & right_header = right->getCurrentDataStream().header; QueryPlanPtr query_plan; + + /// Support only one join clause. + table_join->addDisjunct(); + /// some examples to explain when the post_join_filter is not empty + /// - on t1.key = t2.key and t1.v1 > 1 and t2.v1 > 1, 't1.v1> 1' is in the post filter. but 't2.v1 > 1' + /// will be pushed down into right table by spark and is not in the post filter. 't1.key = t2.key ' is + /// in JoinRel::expression. + /// - on t1.key = t2. key and t1.v1 > t2.v2, 't1.v1 > t2.v2' is in the post filter. 
+ if (join.has_expression()) + collectJoinKeys(*table_join, join.expression(), left_header, right_header); if (storage_join) { + if (join.has_post_join_filter()) + applyJoinFilter(*table_join, join.post_join_filter(), *left, *right, false); + auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); QueryPlanStepPtr join_step = std::make_unique(left->getCurrentDataStream(), broadcast_hash_join, 8192); @@ -223,6 +239,9 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } else if (join_opt_info.is_smj) { + if (join.has_post_join_filter()) + applyJoinFilter(*table_join, join.post_join_filter(), *left, *right, false); + JoinPtr smj_join = std::make_shared(table_join, right->getCurrentDataStream().header.cloneEmpty(), -1); MultiEnum join_algorithm = context->getSettingsRef().join_algorithm; QueryPlanStepPtr join_step @@ -239,9 +258,10 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } else { - /// TODO: make grace hash join be the default hash join algorithm. - /// - /// Following is some configuration for grace hash join. + if (join.has_post_join_filter()) + applyJoinFilter(*table_join, join.post_join_filter(), *left, *right, true); + + /// Following is some configurations for grace hash join. /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.join_algorithm=grace_hash. This will /// enable grace hash join. /// - spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_in_join=3145728. 
This setup @@ -277,10 +297,6 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q } reorderJoinOutput(*query_plan, after_join_names); - if (add_filter_step) - { - addPostFilter(*query_plan, join); - } return query_plan; } @@ -359,117 +375,185 @@ void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, } } -void JoinRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait::JoinRel & join) +/// Join keys are collected from substrait::JoinRel::expression() which only contains the equal join conditions. +void JoinRelParser::collectJoinKeys( + TableJoin & table_join, const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) { - std::string filter_name; - auto actions_dag = std::make_shared(query_plan.getCurrentDataStream().header.getColumnsWithTypeAndName()); - if (!join.post_join_filter().has_scalar_function()) - { - // It may be singular_or_list - auto * in_node = getPlanParser()->parseExpression(actions_dag, join.post_join_filter()); - filter_name = in_node->result_name; - } - else + auto & join_clause = table_join.getClauses().back(); + std::list expressions_stack; + expressions_stack.push_back(&expr); + while (!expressions_stack.empty()) { - getPlanParser()->parseFunction(query_plan.getCurrentDataStream().header, join.post_join_filter(), filter_name, actions_dag, true); + const auto * current_expr = expressions_stack.back(); + expressions_stack.pop_back(); + if (!current_expr->has_scalar_function()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function expression is expected"); + auto function_name = parseFunctionName(current_expr->scalar_function().function_reference(), current_expr->scalar_function()); + if (!function_name) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid function expression"); + if (*function_name == "equals") + { + String left_key, right_key; + for (const auto & arg : current_expr->scalar_function().arguments()) + { 
+ if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference() + || !arg.value().selection().direct_reference().has_struct_field()) + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected"); + } + auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field(); + if (col_pos_ref < left_header.columns()) + { + left_key = left_header.getByPosition(col_pos_ref).name; + } + else + { + right_key = right_header.getByPosition(col_pos_ref - left_header.columns()).name; + } + } + if (left_key.empty() || right_key.empty()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid key equal join condition"); + join_clause.addKey(left_key, right_key, false); + } + else if (*function_name == "and") + { + for (const auto & arg : current_expr->scalar_function().arguments()) + { + expressions_stack.push_back(&arg.value()); + } + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow function: {}", *function_name); + } } - auto filter_step = std::make_unique(query_plan.getCurrentDataStream(), actions_dag, filter_name, true); - filter_step->setStepDescription("Post Join Filter"); - steps.emplace_back(filter_step.get()); - query_plan.addStep(std::move(filter_step)); } -bool JoinRelParser::tryAddPushDownFilter( - TableJoin & table_join, - const substrait::JoinRel & join, - DB::QueryPlan & left, - DB::QueryPlan & right, - const NamesAndTypesList & alias_right, - const Names & names) +std::unordered_set JoinRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) { - try + std::unordered_set table_sides; + if (expr.has_scalar_function()) { - ASTParser astParser(context, function_mapping, getPlanParser()); - ASTs args; - - if (join.has_expression()) + for (const auto & arg : expr.scalar_function().arguments()) { - args.emplace_back(astParser.parseToAST(names, join.expression())); + auto 
table_sides_from_arg = extractTableSidesFromExpression(arg.value(), left_header, right_header); + table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end()); } - - if (join.has_post_join_filter()) + } + else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field()) + { + auto pos = expr.selection().direct_reference().struct_field().field(); + if (pos < left_header.columns()) { - args.emplace_back(astParser.parseToAST(names, join.post_join_filter())); + table_sides.insert(DB::JoinTableSide::Left); } + else + { + table_sides.insert(DB::JoinTableSide::Right); + } + } + else if (expr.has_singular_or_list()) + { + auto child_table_sides = extractTableSidesFromExpression(expr.singular_or_list().value(), left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + for (const auto & option : expr.singular_or_list().options()) + { + child_table_sides = extractTableSidesFromExpression(option, left_header, right_header); + table_sides.insert(child_table_sides.begin(), child_table_sides.end()); + } + } + else if (expr.has_literal()) + { + // nothing + } + else + { + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression '{}'", expr.DebugString()); + } + return table_sides; +} - if (args.empty()) - return false; - - ASTPtr ast = args.size() == 1 ? 
args.back() : makeASTFunction("and", args); - - bool is_asof = (table_join.strictness() == JoinStrictness::Asof); - - Aliases aliases; - DatabaseAndTableWithAlias left_table_name; - DatabaseAndTableWithAlias right_table_name; - TableWithColumnNamesAndTypes left_table(left_table_name, left.getCurrentDataStream().header.getNamesAndTypesList()); - TableWithColumnNamesAndTypes right_table(right_table_name, alias_right); +void JoinRelParser::applyJoinFilter( + DB::TableJoin & table_join, const substrait::Expression & expr, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) +{ + const auto & left_header = left.getCurrentDataStream().header; + const auto & right_header = right.getCurrentDataStream().header; + auto table_sides = extractTableSidesFromExpression(expr, left_header, right_header); - CollectJoinOnKeysVisitor::Data data{table_join, left_table, right_table, aliases, is_asof}; - if (auto * or_func = ast->as(); or_func && or_func->name == "or") + auto get_input_expressions = [](const DB::Block & header) + { + std::vector exprs; + for (size_t i = 0; i < header.columns(); ++i) { - for (auto & disjunct : or_func->arguments->children) - { - table_join.addDisjunct(); - CollectJoinOnKeysVisitor(data).visit(disjunct); - } - assert(table_join.getClauses().size() == or_func->arguments->children.size()); + substrait::Expression expr; + expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i); + exprs.emplace_back(expr); + } + return exprs; + }; + + /// If the columns in the expression are all from one table, use analyzer_left_filter_condition_column_name + /// or analyzer_right_filter_condition_column_name to filter the join result data. It requires building the filter + /// column first. + /// If the columns in the expression are from both tables, use mixed_join_expression to filter the join result data. + /// The filter column will be built inside the join step. 
+ if (table_sides.size() == 1) + { + auto table_side = *table_sides.begin(); + if (table_side == DB::JoinTableSide::Left) + { + auto input_exprs = get_input_expressions(left_header); + input_exprs.push_back(expr); + auto actions_dag = expressionsToActionsDAG(input_exprs, left_header); + table_join.getClauses().back().analyzer_left_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; + QueryPlanStepPtr before_join_step = std::make_unique(left.getCurrentDataStream(), actions_dag); + before_join_step->setStepDescription("Before JOIN LEFT"); + steps.emplace_back(before_join_step.get()); + left.addStep(std::move(before_join_step)); } else { - table_join.addDisjunct(); - CollectJoinOnKeysVisitor(data).visit(ast); - assert(table_join.oneDisjunct()); + auto input_exprs = get_input_expressions(right_header); + input_exprs.push_back(expr); + auto actions_dag = expressionsToActionsDAG(input_exprs, right_header); + table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; + QueryPlanStepPtr before_join_step = std::make_unique(right.getCurrentDataStream(), actions_dag); + before_join_step->setStepDescription("Before JOIN RIGHT"); + steps.emplace_back(before_join_step.get()); + right.addStep(std::move(before_join_step)); } - - if (join.has_post_join_filter()) + } + else if (table_sides.size() == 2) + { + if (!allow_mixed_condition) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Mixed join condition is not allowed"); + ColumnsWithTypeAndName mixed_columns; + std::unordered_set added_column_name; + for (const auto & col : left_header.getColumnsWithTypeAndName()) + { + mixed_columns.emplace_back(col); + added_column_name.insert(col.name); + } + for (const auto & col : right_header.getColumnsWithTypeAndName()) { - auto left_keys = table_join.leftKeysList(); - auto right_keys = table_join.rightKeysList(); - if (!left_keys->children.empty()) + if (added_column_name.find(col.name) == 
added_column_name.end()) { - auto actions = astParser.convertToActions(left.getCurrentDataStream().header.getNamesAndTypesList(), left_keys); - QueryPlanStepPtr before_join_step = std::make_unique(left.getCurrentDataStream(), actions); - before_join_step->setStepDescription("Before JOIN LEFT"); - steps.emplace_back(before_join_step.get()); - left.addStep(std::move(before_join_step)); + mixed_columns.emplace_back(col); } - - if (!right_keys->children.empty()) + else { - auto actions = astParser.convertToActions(right.getCurrentDataStream().header.getNamesAndTypesList(), right_keys); - QueryPlanStepPtr before_join_step = std::make_unique(right.getCurrentDataStream(), actions); - before_join_step->setStepDescription("Before JOIN RIGHT"); - steps.emplace_back(before_join_step.get()); - right.addStep(std::move(before_join_step)); + mixed_columns.emplace_back(col.column, col.type, col.name + "_right"); } } + DB::Block mixed_header(mixed_columns); + auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); + table_join.getMixedJoinExpression() + = std::make_shared(mixed_join_expressions_actions, ExpressionActionsSettings::fromContext(context)); } - // if ch does not support the join type or join conditions, it will throw an exception like 'not support'. - catch (Poco::Exception & e) + else { - // CH not support join condition has 'or' and has different table in each side. - // But in inner join, we could execute join condition after join. 
so we have add filter step - if (e.code() == ErrorCodes::INVALID_JOIN_ON_EXPRESSION && table_join.kind() == DB::JoinKind::Inner) - { - return true; - } - else - { - throw; - } + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition"); } - return false; } void registerJoinRelParser(RelParserFactory & factory) diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index 445b7e683300..99a309b0912a 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -51,14 +52,17 @@ class JoinRelParser : public RelParser DB::QueryPlanPtr parseJoin(const substrait::JoinRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right); void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right); - bool tryAddPushDownFilter( - TableJoin & table_join, - const substrait::JoinRel & join, - DB::QueryPlan & left, - DB::QueryPlan & right, - const NamesAndTypesList & alias_right, - const Names & names); - void addPostFilter(DB::QueryPlan & plan, const substrait::JoinRel & join); + void collectJoinKeys( + TableJoin & table_join, const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); + void applyJoinFilter( + DB::TableJoin & table_join, + const substrait::Expression & expr, + DB::QueryPlan & left_plan, + DB::QueryPlan & right_plan, + bool allow_mixed_condition); + + static std::unordered_set extractTableSidesFromExpression( + const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); }; } diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index a26f78699dc8..b0d3bbeca962 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -214,7 +214,7 
@@ std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( } } } - else if (expr.has_cast() || expr.has_if_then() || expr.has_literal()) + else if (expr.has_cast() || expr.has_if_then() || expr.has_literal() || expr.has_singular_or_list()) { const auto * node = parseExpression(actions_dag, expr); actions_dag->addOrReplaceInOutputs(*node);