diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java index 9cb49b6a2d30e..3782cd22a85e6 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/vectorized/StorageJoinBuilder.java @@ -74,12 +74,20 @@ public static long build( return converter.genColumnNameWithExprId(attr); }) .collect(Collectors.joining(",")); + + int joinType; + if (broadCastContext.buildHashTableId().startsWith("BuiltBNLJBroadcastTable-")) { + joinType = SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(); + } else { + joinType = SubstraitUtil.toCrossRelSubstrait(broadCastContext.joinType()).ordinal(); + } + return nativeBuild( broadCastContext.buildHashTableId(), batches, rowCount, joinKey, - SubstraitUtil.toSubstrait(broadCastContext.joinType()).ordinal(), + joinType, broadCastContext.hasMixedFiltCondition(), toNameStruct(output).toByteArray()); } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index cdca1b031a915..a9d9da7a9d42f 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, DenseRank, Expression, Lag, Lead, Literal, NamedExpression, Rank, RowNumber} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -297,4 +298,9 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } override def mergeTwoPhasesHashBaseAggregateIfNeed(): Boolean = true + + override def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { + case _: InnerLike | LeftOuter | RightOuter | LeftSemi | FullOuter => true + case _ => false + } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala index 55f86bb32cce1..ac7cf67d8f306 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHBroadcastNestedLoopJoinExecTransformer.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} -import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashJoin} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.execution.joins.BuildSideRelation import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.{Any, StringValue} @@ -44,31 +43,7 @@ case class CHBroadcastNestedLoopJoinExecTransformer( condition ) { // Unique ID for builded table - lazy val buildBroadcastTableId: String = "BuiltBroadcastTable-" + buildPlan.id - - lazy val (buildKeyExprs, streamedKeyExprs) = { - require( - leftKeys.length == rightKeys.length && - leftKeys - .map(_.dataType) - .zip(rightKeys.map(_.dataType)) - .forall(types => sameType(types._1, types._2)), - "Join keys from two sides should have same length and types" - ) - // Spark has an improvement which would patch integer joins keys to a Long value. - // But this improvement would cause add extra project before hash join in velox, - // disabling this improvement as below would help reduce the project. - val (lkeys, rkeys) = if (BackendsApiManager.getSettings.enableJoinKeysRewrite()) { - (HashJoin.rewriteKeyExpr(leftKeys), HashJoin.rewriteKeyExpr(rightKeys)) - } else { - (leftKeys, rightKeys) - } - if (needSwitchChildren) { - (lkeys, rkeys) - } else { - (rkeys, lkeys) - } - } + lazy val buildBroadcastTableId: String = "BuiltBNLJBroadcastTable-" + buildPlan.id override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { val streamedRDD = getColumnarInputRDDs(streamedPlan) @@ -106,27 +81,6 @@ case class CHBroadcastNestedLoopJoinExecTransformer( res } - def sameType(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, _), ArrayType(toElement, _)) => - sameType(fromElement, toElement) - - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => - sameType(fromKey, toKey) && - sameType(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { - case (l, r) => - l.name.equalsIgnoreCase(r.name) && - sameType(l.dataType, r.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } - } - override def genJoinParameters(): Any = { val joinParametersStr = new StringBuffer("JoinParameters:") joinParametersStr diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala index 0238508d96995..eab1551f730bb 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxBackend.scala @@ -511,8 +511,6 @@ object VeloxBackendSettings extends BackendSettingsApi { override def supportCartesianProductExec(): Boolean = true - override def supportBroadcastNestedLoopJoinExec(): Boolean = true - override def supportSampleExec(): Boolean = true override def supportColumnarArrowUdf(): Boolean = true diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index f1d3d54809353..92526206b3431 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -51,6 +51,7 @@ #include #include #include +#include #include #include #include @@ -60,7 +61,6 @@ #include #include #include -#include #include #include #include @@ -1077,4 +1077,53 @@ UInt64 MemoryUtil::getMemoryRSS() return rss * sysconf(_SC_PAGESIZE); } + +void JoinUtil::reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) +{ + ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); + NamesWithAliases project_cols; + for (const auto & col : cols) + { + project_cols.emplace_back(NameWithAlias(col, col)); + } + project->project(project_cols); + QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); + project_step->setStepDescription("Reorder Join Output"); + plan.addStep(std::move(project_step)); +} + +std::pair JoinUtil::getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: + return {DB::JoinKind::Inner, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: + return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; + case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: + return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; + case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: + return {DB::JoinKind::Left, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: + return {DB::JoinKind::Right, DB::JoinStrictness::All}; + case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Full, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + +std::pair JoinUtil::getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type) +{ + switch (join_type) + { + case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: + case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: + case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: + return {DB::JoinKind::Cross, DB::JoinStrictness::All}; + default: + throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); + } +} + } diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index ff25415c0cbb7..938ca9d114896 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -16,6 +16,7 @@ * limitations under the License. */ #pragma once + #include #include #include @@ -25,6 +26,8 @@ #include #include #include +#include +#include #include namespace DB @@ -302,4 +305,12 @@ class ConcurrentDeque mutable std::mutex mtx; }; +class JoinUtil +{ +public: + static void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols); + static std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); + static std::pair getCrossJoinKindAndStrictness(substrait::CrossRel_JoinType join_type); +}; + } diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp index 3f3c7e6c32aa8..f0c9612dc5670 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.cpp @@ -97,7 +97,7 @@ std::shared_ptr buildJoin( DB::ReadBuffer & input, jlong row_count, const std::string & join_keys, - substrait::JoinRel_JoinType join_type, + jint join_type, bool has_mixed_join_condition, const std::string & named_struct) { @@ -109,7 +109,11 @@ std::shared_ptr buildJoin( DB::JoinKind kind; DB::JoinStrictness strictness; - std::tie(kind, strictness) = getJoinKindAndStrictness(join_type); + if (key.starts_with("BuiltBNLJBroadcastTable-")) + std::tie(kind, strictness) = JoinUtil::getCrossJoinKindAndStrictness(static_cast(join_type)); + else + std::tie(kind, strictness) = JoinUtil::getJoinKindAndStrictness(static_cast(join_type)); + substrait::NamedStruct substrait_struct; substrait_struct.ParseFromString(named_struct); diff --git a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h index 9a6837e35a0ac..3d2e67f9df101 100644 --- a/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h +++ b/cpp-ch/local-engine/Join/BroadCastJoinBuilder.h @@ -35,7 +35,7 @@ std::shared_ptr buildJoin( DB::ReadBuffer & input, jlong row_count, const std::string & join_keys, - substrait::JoinRel_JoinType join_type, + jint join_type, bool has_mixed_join_condition, const std::string & named_struct); void cleanBuildHashTable(const std::string & hash_table_id, jlong instance); diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.cpp b/cpp-ch/local-engine/Parser/CrossRelParser.cpp index 071588a9f26e1..2b573dc5055dc 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.cpp +++ b/cpp-ch/local-engine/Parser/CrossRelParser.cpp @@ -15,10 +15,10 @@ * limitations under the License. */ #include "CrossRelParser.h" + #include #include #include -#include #include #include #include @@ -30,8 +30,7 @@ #include #include #include - -#include +#include #include @@ -45,24 +44,17 @@ namespace ErrorCodes } } -struct JoinOptimizationInfo -{ - bool is_broadcast = false; - bool is_smj = false; - bool is_null_aware_anti_join = false; - bool is_existence_join = false; - std::string storage_join_key; -}; - using namespace DB; -String parseJoinOptimizationInfos(const substrait::CrossRel & join) + + +namespace local_engine +{ +String parseCrossJoinOptimizationInfos(const substrait::CrossRel & join) { google::protobuf::StringValue optimization; optimization.ParseFromString(join.advanced_extension().optimization().value()); - JoinOptimizationInfo info; - auto a = optimization.value(); String storage_join_key; ReadBufferFromString in(optimization.value()); assertString("JoinParameters:", in); @@ -71,49 +63,13 @@ String parseJoinOptimizationInfos(const substrait::CrossRel & join) return storage_join_key; } -void reorderJoinOutput2(DB::QueryPlan & plan, DB::Names cols) -{ - ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); - NamesWithAliases project_cols; - for (const auto & col : cols) - { - project_cols.emplace_back(NameWithAlias(col, col)); - } - project->project(project_cols); - QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); - project_step->setStepDescription("Reorder Join Output"); - plan.addStep(std::move(project_step)); -} - -namespace local_engine -{ - -std::pair getJoinKindAndStrictness2(substrait::CrossRel_JoinType join_type) -{ - switch (join_type) - { - case substrait::CrossRel_JoinType_JOIN_TYPE_INNER: - case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: - case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Cross, DB::JoinStrictness::All}; - // case substrait::CrossRel_JoinType_JOIN_TYPE_LEFT: - // return {DB::JoinKind::Left, DB::JoinStrictness::All}; - // - // case substrait::CrossRel_JoinType_JOIN_TYPE_RIGHT: - // return {DB::JoinKind::Right, DB::JoinStrictness::All}; - // case substrait::CrossRel_JoinType_JOIN_TYPE_OUTER: - // return {DB::JoinKind::Full, DB::JoinStrictness::All}; - default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); - } -} -std::shared_ptr createDefaultTableJoin2(substrait::CrossRel_JoinType join_type) +std::shared_ptr createCrossTableJoin(substrait::CrossRel_JoinType join_type) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = getJoinKindAndStrictness2(join_type); + std::pair kind_and_strictness = JoinUtil::getCrossJoinKindAndStrictness(join_type); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -154,68 +110,6 @@ DB::QueryPlanPtr CrossRelParser::parseOp(const substrait::Rel & rel, std::list CrossRelParser::extractTableSidesFromExpression(const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header) -{ - std::unordered_set table_sides; - if (expr.has_scalar_function()) - { - for (const auto & arg : expr.scalar_function().arguments()) - { - auto table_sides_from_arg = extractTableSidesFromExpression(arg.value(), left_header, right_header); - table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end()); - } - } - else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field()) - { - auto pos = expr.selection().direct_reference().struct_field().field(); - if (pos < left_header.columns()) - { - table_sides.insert(DB::JoinTableSide::Left); - } - else - { - table_sides.insert(DB::JoinTableSide::Right); - } - } - else if (expr.has_singular_or_list()) - { - auto child_table_sides = extractTableSidesFromExpression(expr.singular_or_list().value(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - for (const auto & option : expr.singular_or_list().options()) - { - child_table_sides = extractTableSidesFromExpression(option, left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - } - else if (expr.has_cast()) - { - auto child_table_sides = extractTableSidesFromExpression(expr.cast().input(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - else if (expr.has_if_then()) - { - for (const auto & if_child : expr.if_then().ifs()) - { - auto child_table_sides = extractTableSidesFromExpression(if_child.if_(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - child_table_sides = extractTableSidesFromExpression(if_child.then(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - auto child_table_sides = extractTableSidesFromExpression(expr.if_then().else_(), left_header, right_header); - table_sides.insert(child_table_sides.begin(), child_table_sides.end()); - } - else if (expr.has_literal()) - { - // nothing - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Illegal expression '{}'", expr.DebugString()); - } - return table_sides; -} - - void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join) { /// To support mixed join conditions, we must make sure that the column names in the right be the same as @@ -266,23 +160,20 @@ void CrossRelParser::renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & rig DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB::QueryPlanPtr left, DB::QueryPlanPtr right) { - auto storage_join_key = parseJoinOptimizationInfos(join); + auto storage_join_key = parseCrossJoinOptimizationInfos(join); auto storage_join = BroadCastJoinBuilder::getJoin(storage_join_key) ; renamePlanColumns(*left, *right, *storage_join); - auto table_join = createDefaultTableJoin2(join.type()); + auto table_join = createCrossTableJoin(join.type()); DB::Block right_header_before_convert_step = right->getCurrentDataStream().header; addConvertStep(*table_join, *left, *right); // Add a check to find error easily. - if (storage_join) + if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header)) { - if(!blocksHaveEqualStructure(right_header_before_convert_step, right->getCurrentDataStream().header)) - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", - left->getCurrentDataStream().header.dumpNames(), - right_header_before_convert_step.dumpNames(), - right->getCurrentDataStream().header.dumpNames()); - } + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "For broadcast join, we must not change the columns name in the right table.\nleft header:{},\nright header: {} -> {}", + left->getCurrentDataStream().header.dumpNames(), + right_header_before_convert_step.dumpNames(), + right->getCurrentDataStream().header.dumpNames()); } Names after_join_names; @@ -295,7 +186,6 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: auto right_header = right->getCurrentDataStream().header; QueryPlanPtr query_plan; - // applyJoinFilter(*table_join, join, *left, *right, true); table_join->addDisjunct(); auto broadcast_hash_join = storage_join->getJoinLocked(table_join, context); // table_join->resetKeys(); @@ -309,7 +199,7 @@ DB::QueryPlanPtr CrossRelParser::parseJoin(const substrait::CrossRel & join, DB: extra_plan_holder.emplace_back(std::move(right)); addPostFilter(*query_plan, join); - reorderJoinOutput2(*query_plan, after_join_names); + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } @@ -327,6 +217,7 @@ void CrossRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait:: { // It may be singular_or_list auto * in_node = getPlanParser()->parseExpression(actions_dag, expression); + auto a = isColumnConst(*in_node->column); filter_name = in_node->result_name; } else @@ -339,102 +230,6 @@ void CrossRelParser::addPostFilter(DB::QueryPlan & query_plan, const substrait:: query_plan.addStep(std::move(filter_step)); } -bool CrossRelParser::applyJoinFilter( - DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition) -{ - if (!join_rel.has_expression()) - return true; - const auto & expr = join_rel.expression(); - - const auto & left_header = left.getCurrentDataStream().header; - const auto & right_header = right.getCurrentDataStream().header; - ColumnsWithTypeAndName mixed_columns; - std::unordered_set added_column_name; - for (const auto & col : left_header.getColumnsWithTypeAndName()) - { - mixed_columns.emplace_back(col); - added_column_name.insert(col.name); - } - for (const auto & col : right_header.getColumnsWithTypeAndName()) - { - const auto & renamed_col_name = table_join.renamedRightColumnNameWithAlias(col.name); - if (added_column_name.find(col.name) != added_column_name.end()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Right column's name conflict with left column: {}", col.name); - mixed_columns.emplace_back(col); - added_column_name.insert(col.name); - } - DB::Block mixed_header(mixed_columns); - - auto table_sides = extractTableSidesFromExpression(expr, left_header, right_header); - - auto get_input_expressions = [](const DB::Block & header) - { - std::vector exprs; - for (size_t i = 0; i < header.columns(); ++i) - { - substrait::Expression expr; - expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i); - exprs.emplace_back(expr); - } - return exprs; - }; - - /// If the columns in the expression are all from one table, use analyzer_left_filter_condition_column_name - /// and analyzer_left_filter_condition_column_name to filt the join result data. It requires to build the filter - /// column at first. - /// If the columns in the expression are from both tables, use mixed_join_expression to filt the join result data. - /// the filter columns will be built inner the join step. - if (table_sides.size() == 1) - { - auto table_side = *table_sides.begin(); - if (table_side == DB::JoinTableSide::Left) - { - auto input_exprs = get_input_expressions(left_header); - input_exprs.push_back(expr); - auto actions_dag = expressionsToActionsDAG(input_exprs, left_header); - table_join.getClauses().back().analyzer_left_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; - QueryPlanStepPtr before_join_step = std::make_unique(left.getCurrentDataStream(), actions_dag); - before_join_step->setStepDescription("Before JOIN LEFT"); - steps.emplace_back(before_join_step.get()); - left.addStep(std::move(before_join_step)); - } - else - { - /// since the field reference in expr is the index of left_header ++ right_header, so we use - /// mixed_header to build the actions_dag - auto input_exprs = get_input_expressions(mixed_header); - input_exprs.push_back(expr); - auto actions_dag = expressionsToActionsDAG(input_exprs, mixed_header); - - /// clear unused columns in actions_dag - for (const auto & col : left_header.getColumnsWithTypeAndName()) - { - actions_dag->removeUnusedResult(col.name); - } - actions_dag->removeUnusedActions(); - - table_join.getClauses().back().analyzer_right_filter_condition_column_name = actions_dag->getOutputs().back()->result_name; - QueryPlanStepPtr before_join_step = std::make_unique(right.getCurrentDataStream(), actions_dag); - before_join_step->setStepDescription("Before JOIN RIGHT"); - steps.emplace_back(before_join_step.get()); - right.addStep(std::move(before_join_step)); - } - } - else if (table_sides.size() == 2) - { - if (!allow_mixed_condition) - return false; - auto mixed_join_expressions_actions = expressionsToActionsDAG({expr}, mixed_header); - table_join.getMixedJoinExpression() - = std::make_shared(mixed_join_expressions_actions, ExpressionActionsSettings::fromContext(context)); - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not any table column is used in the join condition.\n{}", join_rel.DebugString()); - } - return true; -} - void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right) { /// If the columns name in right table is duplicated with left table, we need to rename the right table's columns. @@ -501,68 +296,6 @@ void CrossRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left } } -/// Join keys are collected from substrait::JoinRel::expression() which only contains the equal join conditions. -void CrossRelParser::collectJoinKeys( - TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header) -{ - if (!join_rel.has_expression()) - return; - const auto & expr = join_rel.expression(); - auto & join_clause = table_join.getClauses().back(); - std::list expressions_stack; - expressions_stack.push_back(&expr); - while (!expressions_stack.empty()) - { - /// Must handle the expressions in DF order. It matters in sort merge join. - const auto * current_expr = expressions_stack.back(); - expressions_stack.pop_back(); - if (!current_expr->has_scalar_function()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Function expression is expected"); - auto function_name = parseFunctionName(current_expr->scalar_function().function_reference(), current_expr->scalar_function()); - if (!function_name) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid function expression"); - if (*function_name == "equals") - { - String left_key, right_key; - size_t left_pos = 0, right_pos = 0; - for (const auto & arg : current_expr->scalar_function().arguments()) - { - if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference() - || !arg.value().selection().direct_reference().has_struct_field()) - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected"); - } - auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field(); - if (col_pos_ref < left_header.columns()) - { - left_pos = col_pos_ref; - left_key = left_header.getByPosition(col_pos_ref).name; - } - else - { - right_pos = col_pos_ref - left_header.columns(); - right_key = right_header.getByPosition(col_pos_ref - left_header.columns()).name; - } - } - if (left_key.empty() || right_key.empty()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Invalid key equal join condition"); - join_clause.addKey(left_key, right_key, false); - } - else if (*function_name == "and") - { - expressions_stack.push_back(¤t_expr->scalar_function().arguments().at(1).value()); - expressions_stack.push_back(¤t_expr->scalar_function().arguments().at(0).value()); - } - else if (*function_name == "not") - { - expressions_stack.push_back(¤t_expr->scalar_function().arguments().at(0).value()); - } - else - { - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unknow function: {}", *function_name); - } - } -} void registerCrossRelParser(RelParserFactory & factory) { diff --git a/cpp-ch/local-engine/Parser/CrossRelParser.h b/cpp-ch/local-engine/Parser/CrossRelParser.h index 9766b4e91d242..f1cd60385e26a 100644 --- a/cpp-ch/local-engine/Parser/CrossRelParser.h +++ b/cpp-ch/local-engine/Parser/CrossRelParser.h @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include @@ -31,7 +30,6 @@ namespace local_engine class StorageJoinFromReadBuffer; -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); class CrossRelParser : public RelParser { @@ -56,12 +54,8 @@ class CrossRelParser : public RelParser void renamePlanColumns(DB::QueryPlan & left, DB::QueryPlan & right, const StorageJoinFromReadBuffer & storage_join); void addConvertStep(TableJoin & table_join, DB::QueryPlan & left, DB::QueryPlan & right); void addPostFilter(DB::QueryPlan & query_plan, const substrait::CrossRel & join); - void collectJoinKeys( - TableJoin & table_join, const substrait::CrossRel & join_rel, const DB::Block & left_header, const DB::Block & right_header); bool applyJoinFilter( DB::TableJoin & table_join, const substrait::CrossRel & join_rel, DB::QueryPlan & left, DB::QueryPlan & right, bool allow_mixed_condition); - static std::unordered_set extractTableSidesFromExpression( - const substrait::Expression & expr, const DB::Block & left_header, const DB::Block & right_header); }; } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 9a3cc91baaa99..2a37bcd24b2f2 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -15,6 +15,7 @@ * limitations under the License. */ #include "JoinRelParser.h" + #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include @@ -98,51 +100,15 @@ JoinOptimizationInfo parseJoinOptimizationInfo(const substrait::JoinRel & join) return info; } - -void reorderJoinOutput(DB::QueryPlan & plan, DB::Names cols) -{ - ActionsDAGPtr project = std::make_shared(plan.getCurrentDataStream().header.getNamesAndTypesList()); - NamesWithAliases project_cols; - for (const auto & col : cols) - { - project_cols.emplace_back(NameWithAlias(col, col)); - } - project->project(project_cols); - QueryPlanStepPtr project_step = std::make_unique(plan.getCurrentDataStream(), project); - project_step->setStepDescription("Reorder Join Output"); - plan.addStep(std::move(project_step)); -} - namespace local_engine { - -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type) -{ - switch (join_type) - { - case substrait::JoinRel_JoinType_JOIN_TYPE_INNER: - return {DB::JoinKind::Inner, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT_SEMI: - return {DB::JoinKind::Left, DB::JoinStrictness::Semi}; - case substrait::JoinRel_JoinType_JOIN_TYPE_ANTI: - return {DB::JoinKind::Left, DB::JoinStrictness::Anti}; - case substrait::JoinRel_JoinType_JOIN_TYPE_LEFT: - return {DB::JoinKind::Left, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_RIGHT: - return {DB::JoinKind::Right, DB::JoinStrictness::All}; - case substrait::JoinRel_JoinType_JOIN_TYPE_OUTER: - return {DB::JoinKind::Full, DB::JoinStrictness::All}; - default: - throw Exception(ErrorCodes::UNKNOWN_TYPE, "unsupported join type {}.", magic_enum::enum_name(join_type)); - } -} std::shared_ptr createDefaultTableJoin(substrait::JoinRel_JoinType join_type) { auto & global_context = SerializedPlanParser::global_context; auto table_join = std::make_shared( global_context->getSettings(), global_context->getGlobalTemporaryVolume(), global_context->getTempDataOnDisk()); - std::pair kind_and_strictness = getJoinKindAndStrictness(join_type); + std::pair kind_and_strictness = JoinUtil::getJoinKindAndStrictness(join_type); table_join->setKind(kind_and_strictness.first); table_join->setStrictness(kind_and_strictness.second); return table_join; @@ -436,7 +402,7 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q query_plan = std::make_unique(); query_plan->unitePlans(std::move(join_step), {std::move(plans)}); } - reorderJoinOutput(*query_plan, after_join_names); + JoinUtil::reorderJoinOutput(*query_plan, after_join_names); return query_plan; } diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.h b/cpp-ch/local-engine/Parser/JoinRelParser.h index c423f43908e70..15468b54b6f49 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.h +++ b/cpp-ch/local-engine/Parser/JoinRelParser.h @@ -31,8 +31,6 @@ namespace local_engine class StorageJoinFromReadBuffer; -std::pair getJoinKindAndStrictness(substrait::JoinRel_JoinType join_type); - class JoinRelParser : public RelParser { public: diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 2338bfe8b1e6e..3a8791f531c0e 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -1124,13 +1124,12 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_StorageJoinBuilder_nativeBuild const auto named_struct_a = local_engine::getByteArrayElementsSafe(env, named_struct); const std::string::size_type struct_size = named_struct_a.length(); std::string struct_string{reinterpret_cast(named_struct_a.elems()), struct_size}; - const auto join_type = static_cast(join_type_); const jsize length = env->GetArrayLength(in); local_engine::ReadBufferFromByteArray read_buffer_from_java_array(in, length); DB::CompressedReadBuffer input(read_buffer_from_java_array); local_engine::configureCompressedReadBuffer(input); const auto * obj = make_wrapper(local_engine::BroadCastJoinBuilder::buildJoin( - hash_table_id, input, row_count_, join_key, join_type, has_mixed_join_condition, struct_string)); + hash_table_id, input, row_count_, join_key, join_type_, has_mixed_join_condition, struct_string)); return obj->instance(); LOCAL_ENGINE_JNI_METHOD_END(env, 0) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala index 8ddcc7b7f93e1..b6bccc480601f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/BackendSettingsApi.scala @@ -144,7 +144,10 @@ trait BackendSettingsApi { def supportCartesianProductExec(): Boolean = false - def supportBroadcastNestedLoopJoinExec(): Boolean = true + def supportBroadcastNestedJoinJoinType: JoinType => Boolean = { + case _: InnerLike | LeftOuter | RightOuter => true + case _ => false + } def supportSampleExec(): Boolean = false diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala index 241682ad867a0..dd7968701e1c3 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BroadcastNestedLoopJoinExecTransformer.scala @@ -24,7 +24,7 @@ import org.apache.gluten.utils.SubstraitUtil import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} -import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftOuter, LeftSemi, RightOuter} +import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftExistence, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.joins.BaseJoinExec @@ -80,7 +80,7 @@ abstract class BroadcastNestedLoopJoinExecTransformer( left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case LeftSemi => // LeftExistence(_) + case LeftExistence(_) => left.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) @@ -160,8 +160,9 @@ abstract class BroadcastNestedLoopJoinExecTransformer( } override protected def doValidateInternal(): ValidationResult = { - if (!BackendsApiManager.getSettings.supportBroadcastNestedLoopJoinExec()) { - return ValidationResult.notOk("Broadcast Nested Loop join is not supported in this backend") + if (!BackendsApiManager.getSettings.supportBroadcastNestedJoinJoinType(joinType)) { + return ValidationResult.notOk( + s"Broadcast Nested Loop join is not supported join type $joinType in this backend") } if (substraitJoinType == CrossRel.JoinType.UNRECOGNIZED) { return ValidationResult.notOk(s"$joinType join is not supported with BroadcastNestedLoopJoin") diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala index 56b63ef8457ad..637085743b389 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/validator/Validators.scala @@ -147,8 +147,6 @@ object Validators { case p: SortAggregateExec if !settings.replaceSortAggWithHashAgg => fail(p) case p: CartesianProductExec if !settings.supportCartesianProductExec() => fail(p) - case p: BroadcastNestedLoopJoinExec if !settings.supportBroadcastNestedLoopJoinExec() => - fail(p) case p: TakeOrderedAndProjectExec if !settings.supportColumnarShuffleExec() => fail(p) case _ => pass() } diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala index 9671c7a6bca2c..e8e7ce06feaf4 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala @@ -48,6 +48,10 @@ object SubstraitUtil { // the left and right relations are exchanged and the // join type is reverted. CrossRel.JoinType.JOIN_TYPE_LEFT + case LeftSemi => + CrossRel.JoinType.JOIN_TYPE_LEFT_SEMI + case FullOuter => + CrossRel.JoinType.JOIN_TYPE_OUTER case _ => CrossRel.JoinType.UNRECOGNIZED } diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala index df1c87cb0ccc4..4da7a2f6f11ae 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/ColumnarBroadcastExchangeExec.scala @@ -27,7 +27,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.Statistics -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, IdentityBroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, BroadcastExchangeLike} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf @@ -134,13 +134,6 @@ case class ColumnarBroadcastExchangeExec(mode: BroadcastMode, child: SparkPlan) } override protected def doValidateInternal(): ValidationResult = { - // CH backend does not support IdentityBroadcastMode used in BNLJ - if ( - mode == IdentityBroadcastMode && !BackendsApiManager.getSettings - .supportBroadcastNestedLoopJoinExec() - ) { - return ValidationResult.notOk("This backend does not support IdentityBroadcastMode and BNLJ") - } BackendsApiManager.getValidatorApiInstance .doSchemaValidate(schema) .map {